From 6151592ea795a4bf5950a16b917b89a21e4131c4 Mon Sep 17 00:00:00 2001 From: ryan-mangeno Date: Thu, 21 Aug 2025 12:38:04 -0400 Subject: [PATCH 01/26] constants and tensor mappings for modern bert support, model not supported yet but working on getting conversion to work for encoder only --- convert_hf_to_gguf.py | 36 ++++++++++++++++++++++++++++++++++ gguf-py/gguf/constants.py | 14 +++++++++++++ gguf-py/gguf/tensor_mapping.py | 11 ++++++++++- src/llama-arch.h | 1 + 4 files changed, 61 insertions(+), 1 deletion(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index b8c7d97a786c7..6251529e54d88 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -133,6 +133,7 @@ def get_remote_tensors() -> Iterator[tuple[str, Tensor]]: self.ftype = gguf.LlamaFileType.MOSTLY_BF16 # Configure GGUF Writer + print(f"arch: {gguf.MODEL_ARCH_NAMES[self.model_arch]}") self.gguf_writer = gguf.GGUFWriter(path=None, arch=gguf.MODEL_ARCH_NAMES[self.model_arch], endianess=self.endianess, use_temp_file=self.use_temp_file, split_max_tensors=split_max_tensors, split_max_size=split_max_size, dry_run=dry_run, small_first_shard=small_first_shard) @@ -465,6 +466,7 @@ def print_registered_models(cls): @classmethod def from_model_architecture(cls, arch: str, model_type = ModelType.TEXT) -> type[ModelBase]: try: + print(f"model_type: {model_type}, arch: {arch}") return cls._model_classes[model_type][arch] except KeyError: raise NotImplementedError(f'Architecture {arch!r} not supported!') from None @@ -8303,6 +8305,40 @@ def prepare_tensors(self): experts = [k for d in self._experts for k in d.keys()] if len(experts) > 0: raise ValueError(f"Unprocessed experts: {experts}") + + +@ModelBase.register("ModernBertModel") +class ModernBertModel(TextModel): + model_arch = gguf.MODEL_ARCH.MODERN_BERT + + def set_gguf_parameters(self) -> None: + # Determine block count (number of hidden layers) + block_count = self.hparams.get("num_hidden_layers") or self.hparams.get("num_hidden_layers_alt") + if block_count is None: + raise ValueError("Could not determine number of hidden layers from hparams") + + # Attention heads and dimensions + n_head = self.hparams.get("num_attention_heads") + if n_head is None: + raise ValueError("Missing 'num_attention_heads' in hparams") + + hidden_size = self.hparams["hidden_size"] + head_dim = hidden_size // n_head + ffn_dim = self.hparams.get("intermediate_size", 4 * hidden_size) + + # GGUF parameter assignment + self.gguf_writer.add_context_length(self.hparams.get("max_position_embeddings", 512)) + self.gguf_writer.add_embedding_length(hidden_size) + self.gguf_writer.add_feed_forward_length(ffn_dim) + self.gguf_writer.add_block_count(block_count) + self.gguf_writer.add_head_count(n_head) + self.gguf_writer.add_layer_norm_eps(self.hparams.get("layer_norm_eps", 1e-12)) + self.gguf_writer.add_file_type(self.ftype) + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + # Directly map tensor names without QKV splitting or reordering + return [(self.map_tensor_name(name), data_torch)] + ###### CONVERSION LOGIC ###### diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index 911eea504a19e..1273ca31d5830 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -311,6 +311,7 @@ class MODEL_ARCH(IntEnum): STARCODER = auto() REFACT = auto() BERT = auto() + MODERN_BERT = auto() NOMIC_BERT = auto() NOMIC_BERT_MOE = auto() NEO_BERT = auto() @@ -642,6 +643,7 @@ class MODEL_TENSOR(IntEnum): MODEL_ARCH.STARCODER: 
"starcoder", MODEL_ARCH.REFACT: "refact", MODEL_ARCH.BERT: "bert", + MODEL_ARCH.MODERN_BERT: "modern-bert", MODEL_ARCH.NOMIC_BERT: "nomic-bert", MODEL_ARCH.NOMIC_BERT_MOE: "nomic-bert-moe", MODEL_ARCH.NEO_BERT: "neo-bert", @@ -1172,6 +1174,18 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.CLS, MODEL_TENSOR.CLS_OUT, ], + MODEL_ARCH.MODERN_BERT: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.TOKEN_EMBD_NORM, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.ATTN_QKV, + MODEL_TENSOR.POS_EMBD, + MODEL_TENSOR.FFN_UP, + MODEL_TENSOR.FFN_DOWN, + MODEL_TENSOR.FFN_NORM, + ], MODEL_ARCH.NOMIC_BERT: [ MODEL_TENSOR.TOKEN_EMBD, MODEL_TENSOR.TOKEN_EMBD_NORM, diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py index dc7c03b464c25..2d3c16ab8496c 100644 --- a/gguf-py/gguf/tensor_mapping.py +++ b/gguf-py/gguf/tensor_mapping.py @@ -16,6 +16,7 @@ class TensorNameMap: "model.embed_tokens", # llama-hf nemotron olmoe olmo2 rwkv6qwen2 glm4-0414 plamo2 granite-hybrid "tok_embeddings", # llama-pth "embeddings.word_embeddings", # bert nomic-bert + "embeddings.tok_embeddings", # modern bert "language_model.embedding.word_embeddings", # persimmon "wte", # gpt2 "transformer.embd.wte", # phi2 @@ -45,6 +46,7 @@ class TensorNameMap: MODEL_TENSOR.TOKEN_EMBD_NORM: ( "word_embeddings_layernorm", # bloom "embeddings.LayerNorm", # bert + "embeddings.norm", # modern bert "emb_ln", # nomic-bert "transformer.norm", # openelm "rwkv.blocks.0.pre_ln", # rwkv @@ -98,6 +100,7 @@ class TensorNameMap: "backbone.final_layer_norm", # wavtokenizer "model.norm", # llama4 "model.transformer.ln_f", # llada + "final_norm", # modern bert ), # Rope frequencies @@ -142,9 +145,10 @@ class TensorNameMap: "model.layers.{bid}.ln1", # rwkv7 "model.layers.{bid}.input_layernorm", # llama4 "transformer_encoder.{bid}.attention_norm", # neobert + "layers.{bid}.attn_norm", # bert "model.layers.{bid}.operator_norm", # lfm2 "model.transformer.blocks.{bid}.attn_norm", # llada - "layers.{bid}.input_layernorm", # qwen3-embedding + "layers.{bid}.input_layernorm", # qwen3-embedding, ), # Attention norm 2 @@ -174,6 +178,7 @@ class TensorNameMap: "encoder.layers.{bid}.self_attention.query_key_value", # chatglm "transformer.layers.{bid}.attn.qkv_proj", # openelm "transformer_encoder.{bid}.qkv", # neobert + "layers.{bid}.attn.Wqkv", # modern bert ), # Attention query @@ -240,6 +245,7 @@ class TensorNameMap: "model.layers.{bid}.self_attn.linear_attn", # deci "layers.{bid}.attention.wo", # llama-pth "encoder.layer.{bid}.attention.output.dense", # bert + "layers.{bid}.attn.Wo", # modern bert "transformer.layer.{bid}.attention.out_lin", # distillbert "transformer.h.{bid}.attn.out_proj", # gpt-j "language_model.encoder.layers.{bid}.self_attention.dense", # persimmon @@ -311,6 +317,7 @@ class TensorNameMap: "model.layers.layers.{bid}.pre_mlp_norm", # plamo2 "model.transformer.blocks.{bid}.ff_norm", # llada "layers.{bid}.post_attention_layernorm", # qwen3-embedding + "layers.{bid}.mlp_norm" # modern bert ), # Post feed-forward norm @@ -360,6 +367,7 @@ class TensorNameMap: "model.layers.{bid}.mlp.up_proj", # llama-hf refact nemotron olmo2 "layers.{bid}.feed_forward.w3", # llama-pth "encoder.layer.{bid}.intermediate.dense", # bert + "layers.{bid}.mlp.Wo", # modern bert "transformer.layer.{bid}.ffn.lin1", # distillbert "transformer.h.{bid}.mlp.fc_in", # gpt-j "transformer.h.{bid}.mlp.linear_3", # refact @@ -459,6 +467,7 @@ class TensorNameMap: "model.layers.{bid}.mlp.down_proj", # llama-hf nemotron olmo2 
"layers.{bid}.feed_forward.w2", # llama-pth "encoder.layer.{bid}.output.dense", # bert + "layers.{bid}.mlp.Wi", # modern bert "transformer.layer.{bid}.ffn.lin2", # distillbert "transformer.h.{bid}.mlp.fc_out", # gpt-j "language_model.encoder.layers.{bid}.mlp.dense_4h_to_h", # persimmon diff --git a/src/llama-arch.h b/src/llama-arch.h index 7af587e7951bc..c99448e78f481 100644 --- a/src/llama-arch.h +++ b/src/llama-arch.h @@ -22,6 +22,7 @@ enum llm_arch { LLM_ARCH_STARCODER, LLM_ARCH_REFACT, LLM_ARCH_BERT, + LLM_ARCH_MODERN_BERT, LLM_ARCH_NOMIC_BERT, LLM_ARCH_NOMIC_BERT_MOE, LLM_ARCH_NEO_BERT, From 6643c5a852a939bfc8d9750b5c1eee9765c1a4c1 Mon Sep 17 00:00:00 2001 From: ryan-mangeno Date: Thu, 21 Aug 2025 12:42:32 -0400 Subject: [PATCH 02/26] conversion now working, hf -> gguf --- convert_hf_to_gguf.py | 3 +++ convert_hf_to_gguf_update.py | 1 + 2 files changed, 4 insertions(+) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 6251529e54d88..6ed2587f8a9b5 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -860,6 +860,9 @@ def get_vocab_base_pre(self, tokenizer) -> str: if chkhsh == "a1e163ecab2e718a4c829d1148b6e86824ec36163bb71941c3dca9cd5ac25756": # ref: https://huggingface.co/JetBrains/Mellum-4b-base res = "mellum" + if chkhsh == "a0b64b4385f123663873756336c085744376d015ff328bb1d901598f63c44152": + # ref: https://huggingface.co/ibm-granite/granite-embedding-small-english-r2 + res = "modern-bert" if res is None: logger.warning("\n") diff --git a/convert_hf_to_gguf_update.py b/convert_hf_to_gguf_update.py index 575e05e193c2e..6eeb1f64e3bed 100755 --- a/convert_hf_to_gguf_update.py +++ b/convert_hf_to_gguf_update.py @@ -139,6 +139,7 @@ class TOKENIZER_TYPE(IntEnum): {"name": "lfm2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/LiquidAI/LFM2-Tokenizer"}, {"name": "exaone4", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/LGAI-EXAONE/EXAONE-4.0-32B", }, {"name": "mellum", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/JetBrains/Mellum-4b-base", }, + {"name": "modern-bert", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/ibm-granite/granite-embedding-small-english-r2", }, ] # some models are known to be broken upstream, so we will skip them as exceptions From ac67fc68871978c8e92ee98eeb60ebf36ca4c858 Mon Sep 17 00:00:00 2001 From: ryan-mangeno Date: Mon, 25 Aug 2025 16:15:40 -0400 Subject: [PATCH 03/26] working on support, now working on building graph --- src/llama-arch.cpp | 16 +++ src/llama-graph.cpp | 4 +- src/llama-model.cpp | 251 ++++++++++++++++++++++++++++++++++++++++++-- src/llama-model.h | 1 + src/llama-vocab.cpp | 5 +- src/llama.cpp | 1 + 6 files changed, 265 insertions(+), 13 deletions(-) diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp index 18dcc6ddfe567..031b4c486f609 100644 --- a/src/llama-arch.cpp +++ b/src/llama-arch.cpp @@ -18,6 +18,7 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_STARCODER, "starcoder" }, { LLM_ARCH_REFACT, "refact" }, { LLM_ARCH_BERT, "bert" }, + { LLM_ARCH_MODERN_BERT, "modern-bert" }, { LLM_ARCH_NOMIC_BERT, "nomic-bert" }, { LLM_ARCH_NOMIC_BERT_MOE, "nomic-bert-moe" }, { LLM_ARCH_NEO_BERT, "neo-bert" }, @@ -505,6 +506,21 @@ static const std::map> LLM_TENSOR_N { LLM_TENSOR_CLS_OUT, "cls.output" }, }, }, + { + LLM_ARCH_MODERN_BERT, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_TOKEN_EMBD_NORM, "token_embd_norm" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, 
+ { LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" }, + { LLM_TENSOR_ATTN_ROT_EMBD, "blk.%d.attn_rot_embd" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + }, + }, { LLM_ARCH_NOMIC_BERT, { diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index 053c72d6dc8d1..417552096f494 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -1375,7 +1375,9 @@ ggml_tensor * llm_graph_context::build_attn( // [TAG_NO_CACHE_PAD] // TODO: if ubatch.equal_seqs() == true, we can split the three tensors below into ubatch.n_seqs_unq streams - assert(!ubatch.equal_seqs()); + LLAMA_LOG_INFO("ubatch.equal_seqs() = %d, n_seqs = %d\n", ubatch.equal_seqs(), ubatch.n_seqs); + + //assert(!ubatch.equal_seqs()); ggml_tensor * q = q_cur; ggml_tensor * k = k_cur; diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 58ca7df707ef3..67fc2d003c6ee 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -451,6 +451,7 @@ void llama_model::load_arch(llama_model_loader & ml) { } void llama_model::load_hparams(llama_model_loader & ml) { + const gguf_context * ctx = ml.meta.get(); // get metadata as string @@ -464,6 +465,7 @@ void llama_model::load_hparams(llama_model_loader & ml) { gguf_kv.emplace(name, value); } + // get general kv ml.get_key(LLM_KV_GENERAL_NAME, name, false); @@ -584,6 +586,7 @@ void llama_model::load_hparams(llama_model_loader & ml) { } // arch-specific KVs + LLAMA_LOG_INFO("Switching Arch\n"); switch (arch) { case LLM_ARCH_LLAMA: { @@ -757,6 +760,16 @@ void llama_model::load_hparams(llama_model_loader & ml) { default: type = LLM_TYPE_UNKNOWN; } } break; + case LLM_ARCH_MODERN_BERT: + { + //ml.get_key(LLM_KV_ATTENTION_CAUSAL, hparams.causal_attn); + LLAMA_LOG_INFO("Switching Modern Bert Arch\n"); + switch (hparams.n_layer) { + case 12: + type = LLM_TYPE_47M; break; // granite-embeddings-mall + default: type = LLM_TYPE_UNKNOWN; + } + } break; case LLM_ARCH_JINA_BERT_V2: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); @@ -1888,7 +1901,6 @@ void llama_model::load_hparams(llama_model_loader & ml) { void llama_model::load_vocab(llama_model_loader & ml) { const auto kv = LLM_KV(arch); - vocab.load(ml, kv); } @@ -2022,6 +2034,22 @@ bool llama_model::load_tensors(llama_model_loader & ml) { const int64_t n_expert_used = hparams.n_expert_used; const int64_t n_ctx_train = hparams.n_ctx_train; + LLAMA_LOG_INFO("n_head = %lld\n", (long long) n_head); + LLAMA_LOG_INFO("n_head_kv = %lld\n", (long long) n_head_kv); + LLAMA_LOG_INFO("n_embd = %lld\n", (long long) n_embd); + LLAMA_LOG_INFO("n_embd_k_gqa = %lld\n", (long long) n_embd_k_gqa); + LLAMA_LOG_INFO("n_embd_v_gqa = %lld\n", (long long) n_embd_v_gqa); + LLAMA_LOG_INFO("n_embd_head_k = %lld\n", (long long) n_embd_head_k); + LLAMA_LOG_INFO("n_embd_head_v = %lld\n", (long long) n_embd_head_v); + LLAMA_LOG_INFO("n_ff = %lld\n", (long long) n_ff); + LLAMA_LOG_INFO("n_embd_gqa = %lld\n", (long long) n_embd_gqa); + LLAMA_LOG_INFO("n_vocab = %lld\n", (long long) n_vocab); + LLAMA_LOG_INFO("n_token_types = %lld\n", (long long) n_token_types); + LLAMA_LOG_INFO("n_rot = %lld\n", (long long) n_rot); + LLAMA_LOG_INFO("n_expert = %lld\n", (long long) n_expert); + LLAMA_LOG_INFO("n_expert_used = %lld\n", (long long) n_expert_used); + LLAMA_LOG_INFO("n_ctx_train = %lld\n", (long long) n_ctx_train); + if (n_expert > 0 && hparams.n_expert_used == 0) { throw std::runtime_error("model has expert layers but no expert layers are used"); } @@ -2033,7 +2061,7 @@ bool 
llama_model::load_tensors(llama_model_loader & ml) { auto create_tensor = [&](const LLM_TN_IMPL & tn, const std::initializer_list & ne, int flags) -> ggml_tensor * { ggml_tensor * t_meta = ml.get_tensor_meta(tn.str().c_str()); - + LLAMA_LOG_INFO("Creating Tensor: %s\n", tn.str().c_str()); if (!t_meta) { if (flags & TENSOR_NOT_REQUIRED) { return nullptr; @@ -2108,7 +2136,6 @@ bool llama_model::load_tensors(llama_model_loader & ml) { } ggml_backend_buffer_type_t buft = nullptr; - // check overrides if (ml.tensor_buft_overrides) { std::string tensor_name = tn.str(); @@ -2156,7 +2183,6 @@ bool llama_model::load_tensors(llama_model_loader & ml) { first_moved_to_buft = buft; } } - ggml_context * ctx = ctx_for_buft(buft); // if duplicated, check if the original tensor was allocated in the same buffer type context and avoid creating a new one @@ -2614,11 +2640,14 @@ bool llama_model::load_tensors(llama_model_loader & ml) { case LLM_ARCH_NOMIC_BERT_MOE: { tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - type_embd = create_tensor(tn(LLM_TENSOR_TOKEN_TYPES, "weight"), {n_embd, n_token_types}, TENSOR_NOT_REQUIRED); + tok_norm = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {n_embd}, 0); + + type_embd = create_tensor(tn(LLM_TENSOR_TOKEN_TYPES, "weight"), {n_embd, n_token_types}, TENSOR_NOT_REQUIRED); + tok_norm_b = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias"), {n_embd}, 0); + if (arch == LLM_ARCH_BERT) { pos_embd = create_tensor(tn(LLM_TENSOR_POS_EMBD, "weight"), {n_embd, n_ctx_train}, 0); - cls = create_tensor(tn(LLM_TENSOR_CLS, "weight"), {n_embd, n_embd}, TENSOR_NOT_REQUIRED); cls_b = create_tensor(tn(LLM_TENSOR_CLS, "bias"), {n_embd}, TENSOR_NOT_REQUIRED); @@ -2626,14 +2655,11 @@ bool llama_model::load_tensors(llama_model_loader & ml) { cls_out_b = create_tensor(tn(LLM_TENSOR_CLS_OUT, "bias"), {hparams.n_cls_out}, TENSOR_NOT_REQUIRED); } - tok_norm = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {n_embd}, 0); - tok_norm_b = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias"), {n_embd}, 0); - for (int i = 0; i < n_layer; ++i) { auto & layer = layers[i]; layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, TENSOR_NOT_REQUIRED); - layer.bqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, TENSOR_NOT_REQUIRED); + if (!layer.wqkv) { layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0); @@ -2647,7 +2673,8 @@ bool llama_model::load_tensors(llama_model_loader & ml) { } layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); - + + layer.bqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, TENSOR_NOT_REQUIRED); layer.attn_out_norm = create_tensor(tn(LLM_TENSOR_ATTN_OUT_NORM, "weight", i), {n_embd}, 0); layer.attn_out_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT_NORM, "bias", i), {n_embd}, 0); @@ -2657,6 +2684,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff, n_embd, n_expert}, 0); layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); } else { + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); @@ -2671,6 +2699,33 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.layer_out_norm = 
create_tensor(tn(LLM_TENSOR_LAYER_OUT_NORM, "weight", i), {n_embd}, 0); layer.layer_out_norm_b = create_tensor(tn(LLM_TENSOR_LAYER_OUT_NORM, "bias", i), {n_embd}, 0); + + } + } break; + case LLM_ARCH_MODERN_BERT: + { + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + tok_norm = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {n_embd}, 0); + + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + + for(int i = 0; i < n_layer; ++i) { + auto& layer = layers[i]; + + // layer 0 uses identity so we dont need weights for said layer + if ( i != 0 ) { + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + } + else{ + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, TENSOR_NOT_REQUIRED); + } + + layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_ff, n_embd} , 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_embd, n_ff * 2}, 0); + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); } } break; case LLM_ARCH_NEO_BERT: @@ -7498,6 +7553,175 @@ struct llm_build_bert : public llm_graph_context { } }; +struct llm_build_modern_bert : public llm_graph_context { + llm_build_modern_bert(const llama_model & model, const llm_graph_params & params) + : llm_graph_context(params) { + const int64_t n_embd = hparams.n_embd; + const int64_t n_layer = hparams.n_layer; + const int64_t n_head = hparams.n_head(); + const int64_t n_head_kv = hparams.n_head_kv(); + const int64_t n_embd_head = hparams.n_embd_head_v; + const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); // == n_head_kv * n_embd_head + const int64_t n_tokens = ubatch.n_tokens; + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + + // RoPE params + const int32_t rope_type = LLAMA_ROPE_TYPE_NEOX; // ModernBERT uses rotary + const int32_t n_rot = hparams.n_rot; + const int32_t n_ctx_orig = hparams.n_ctx_train; + + ggml_tensor * cur; + ggml_tensor * inpL; + ggml_tensor * inp_pos = nullptr; + + // ModernBERT needs positions for RoPE + inp_pos = build_inp_pos(); + + // 1) embeddings (token + optional type), NO absolute pos embed + inpL = build_inp_embd(model.tok_embd); + + if (model.type_embd) { + ggml_tensor * type_row0 = ggml_view_1d(ctx0, model.type_embd, n_embd, 0); + inpL = ggml_add(ctx0, inpL, type_row0); + } + cb(inpL, "inp_embd", -1); + + // 2) embeddings LayerNorm (embeddings.norm) + inpL = build_norm(inpL, model.tok_norm, model.tok_norm_b, LLM_NORM, -1); + cb(inpL, "inp_norm", -1); + + auto * inp_attn = build_attn_inp_no_cache(); + ggml_tensor * inp_out_ids = build_inp_out_ids(); + + for (int il = 0; il < n_layer; ++il) { + ggml_tensor * x = inpL; + + // pre-attention norm (attn_norm). 
Layer 0 may be Identity() -> nullptr + ggml_tensor * x_attn_in = x; + if (model.layers[il].attn_norm) { + x_attn_in = build_norm(x, + model.layers[il].attn_norm, + model.layers[il].attn_norm_b, + LLM_NORM, il); + cb(x_attn_in, "attn_pre_norm", il); + } else { + cb(x_attn_in, "attn_pre_norm_identity", il); + } + + // Attention: fused Wqkv -> split -> heads -> RoPE(Q,K) -> attn -> Wo + ggml_tensor * qkv = nullptr; + ggml_tensor * Qcur; + ggml_tensor * Kcur; + ggml_tensor * Vcur; + + GGML_ASSERT(model.layers[il].wqkv); // ModernBERT uses fused QKV + qkv = build_lora_mm(model.layers[il].wqkv, x_attn_in); + cb(qkv, "wqkv", il); + + if (model.layers[il].bqkv) { + qkv = ggml_add(ctx0, qkv, model.layers[il].bqkv); + cb(qkv, "bqkv", il); + } + + // Fused layout: [ (n_embd + 2*n_embd_gqa), n_tokens ] + Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, qkv, n_embd, n_tokens, qkv->nb[1], 0*sizeof(float)*(n_embd))); + Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, qkv, n_embd_gqa, n_tokens, qkv->nb[1], 1*sizeof(float)*(n_embd))); + Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, qkv, n_embd_gqa, n_tokens, qkv->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa))); + + // Optional per Q/K + if (model.layers[il].attn_q_norm) { + Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, model.layers[il].attn_q_norm_b, LLM_NORM, il); + } + if (model.layers[il].attn_k_norm) { + Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, model.layers[il].attn_k_norm_b, LLM_NORM, il); + } + + // Heads + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + + // RoPE (NEOX) on Q and K + Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow); + Kcur = ggml_rope_ext(ctx0, Kcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow); + + cb(Qcur, "Qcur_rope", il); + cb(Kcur, "Kcur_rope", il); + cb(Vcur, "Vcur", il); + + ggml_tensor * attn_out = build_attn( + inp_attn, + model.layers[il].wo, model.layers[il].bo, // Wo, optional bias + Qcur, Kcur, Vcur, + /*K_cache*/ nullptr, + /*V_cache*/ nullptr, + 1.0f / sqrtf(float(n_embd_head)), + il); + cb(attn_out, "attn_out", il); + + // Residual after attention + ggml_tensor * cur_attn = ggml_add(ctx0, attn_out, x); + + // If we subselect outputs, do it at the last layer after attn resid + if (il == n_layer - 1 && inp_out_ids) { + cur_attn = ggml_get_rows(ctx0, cur_attn, inp_out_ids); + x = ggml_get_rows(ctx0, x, inp_out_ids); + } + + // 5) pre-MLP norm (mlp_norm) + ggml_tensor * h = build_norm(cur_attn, + model.layers[il].ffn_norm, + model.layers[il].ffn_norm_b, + LLM_NORM, il); + cb(h, "mlp_pre_norm", il); + + // 6) MLP (prefer GEGLU if gate exists or up has 2*n_ff rows) + ggml_tensor * mlp_out = nullptr; + const bool has_gate_tensor = (model.layers[il].ffn_gate != nullptr); + const bool up_is_2x = (model.layers[il].ffn_up && model.layers[il].ffn_up->ne[0] == 2*hparams.n_ff()); + + if (has_gate_tensor || up_is_2x) { + mlp_out = build_ffn( + h, + model.layers[il].ffn_up, /*up_b*/ nullptr, /*up_shexp*/ nullptr, + model.layers[il].ffn_gate, /*gate_b*/ nullptr, /*gate_shexp*/ nullptr, + model.layers[il].ffn_down, /*down_b*/ nullptr, /*down_shexp*/ nullptr, + /*expert_scores*/ nullptr, + LLM_FFN_GEGLU, LLM_FFN_PAR, il); + cb(mlp_out, "ffn_out_geglu", il); + } else { + mlp_out = build_ffn( + 
h, + model.layers[il].ffn_up, /*up_b*/ nullptr, /*up_shexp*/ nullptr, + /*gate*/ nullptr, /*gate_b*/ nullptr, /*gate_shexp*/ nullptr, + model.layers[il].ffn_down, /*down_b*/ nullptr, /*down_shexp*/ nullptr, + /*expert_scores*/ nullptr, + LLM_FFN_GELU, LLM_FFN_SEQ, il); + cb(mlp_out, "ffn_out_gelu", il); + } + + // 7) Residual after MLP + ggml_tensor * cur_layer = ggml_add(ctx0, mlp_out, cur_attn); + + // 8) feed into next layer + inpL = cur_layer; + } + + // 9) final model norm (final_norm) + cur = build_norm(inpL, model.output_norm, model.output_norm_b, LLM_NORM, -1); + cb(cur, "final_norm", -1); + + res->t_embd = cur; + ggml_build_forward_expand(gf, cur); + } +}; + + struct llm_build_neo_bert : public llm_graph_context { llm_build_neo_bert(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v; @@ -18186,6 +18410,10 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const { { llm = std::make_unique(*this, params); } break; + case LLM_ARCH_MODERN_BERT: + { + llm = std::make_unique(*this, params); + } break; case LLM_ARCH_NEO_BERT: { llm = std::make_unique(*this, params); @@ -18666,6 +18894,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) { case LLM_ARCH_GROK: case LLM_ARCH_DBRX: case LLM_ARCH_BERT: + case LLM_ARCH_MODERN_BERT: case LLM_ARCH_NOMIC_BERT: case LLM_ARCH_NOMIC_BERT_MOE: case LLM_ARCH_STABLELM: diff --git a/src/llama-model.h b/src/llama-model.h index 6fcd74d57fdca..5ebe320e36587 100644 --- a/src/llama-model.h +++ b/src/llama-model.h @@ -23,6 +23,7 @@ enum llm_type { LLM_TYPE_17M, LLM_TYPE_22M, LLM_TYPE_33M, + LLM_TYPE_47M, LLM_TYPE_60M, LLM_TYPE_70M, LLM_TYPE_80M, diff --git a/src/llama-vocab.cpp b/src/llama-vocab.cpp index f7e03e702ed19..92a21b6426f3f 100644 --- a/src/llama-vocab.cpp +++ b/src/llama-vocab.cpp @@ -1661,10 +1661,13 @@ struct llama_vocab::impl { void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { struct gguf_context * ctx = ml.meta.get(); + LLAMA_LOG_INFO("Determining Vocab Type\n"); // determine vocab type { ml.get_key(LLM_KV_TOKENIZER_MODEL, tokenizer_model); ml.get_key(LLM_KV_TOKENIZER_PRE, tokenizer_pre, false); + LLAMA_LOG_INFO("pre tokenizer model: %s\n", tokenizer_pre.c_str()); + LLAMA_LOG_INFO("tokenizer model: %s\n", tokenizer_model.c_str()); ml.get_key(LLM_KV_TOKENIZER_TOKEN_TYPE_COUNT, n_token_types, false); @@ -1813,7 +1816,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { LLAMA_LOG_WARN("%s: ************************************ \n", __func__); LLAMA_LOG_WARN("%s: \n", __func__); pre_type = LLAMA_VOCAB_PRE_TYPE_DEFAULT; - } else if (tokenizer_pre == "default") { + } else if (tokenizer_pre == "default" || tokenizer_pre == "modern-bert") { pre_type = LLAMA_VOCAB_PRE_TYPE_DEFAULT; } else if ( tokenizer_pre == "llama3" || diff --git a/src/llama.cpp b/src/llama.cpp index 34906cdb62844..024e142453768 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -126,6 +126,7 @@ static int llama_model_load(const std::string & fname, std::vector if (!model.load_tensors(ml)) { return -2; } + } catch (const std::exception & err) { LLAMA_LOG_ERROR("%s: error loading model: %s\n", __func__, err.what()); return -1; From cc40378d27faddf7a2d791cf0df53c35ad9ff6b5 Mon Sep 17 00:00:00 2001 From: ryan-mangeno Date: Mon, 25 Aug 2025 16:31:08 -0400 Subject: [PATCH 04/26] some cleanup --- src/llama-graph.cpp | 2 +- src/llama-model.cpp | 14 +++++++------- src/llama-vocab.cpp | 2 +- 3 files changed, 9 
insertions(+), 9 deletions(-) diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index 417552096f494..a05de6e585831 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -1377,7 +1377,7 @@ ggml_tensor * llm_graph_context::build_attn( // TODO: if ubatch.equal_seqs() == true, we can split the three tensors below into ubatch.n_seqs_unq streams LLAMA_LOG_INFO("ubatch.equal_seqs() = %d, n_seqs = %d\n", ubatch.equal_seqs(), ubatch.n_seqs); - //assert(!ubatch.equal_seqs()); + assert(!ubatch.equal_seqs()); ggml_tensor * q = q_cur; ggml_tensor * k = k_cur; diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 67fc2d003c6ee..e31f5e51593da 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -7578,7 +7578,7 @@ struct llm_build_modern_bert : public llm_graph_context { // ModernBERT needs positions for RoPE inp_pos = build_inp_pos(); - // 1) embeddings (token + optional type), NO absolute pos embed + // embeddings (token + optional type), NO absolute pos embed inpL = build_inp_embd(model.tok_embd); if (model.type_embd) { @@ -7587,7 +7587,7 @@ struct llm_build_modern_bert : public llm_graph_context { } cb(inpL, "inp_embd", -1); - // 2) embeddings LayerNorm (embeddings.norm) + // embeddings LayerNorm (embeddings.norm) inpL = build_norm(inpL, model.tok_norm, model.tok_norm_b, LLM_NORM, -1); cb(inpL, "inp_norm", -1); @@ -7673,14 +7673,14 @@ struct llm_build_modern_bert : public llm_graph_context { x = ggml_get_rows(ctx0, x, inp_out_ids); } - // 5) pre-MLP norm (mlp_norm) + // pre-MLP norm (mlp_norm) ggml_tensor * h = build_norm(cur_attn, model.layers[il].ffn_norm, model.layers[il].ffn_norm_b, LLM_NORM, il); cb(h, "mlp_pre_norm", il); - // 6) MLP (prefer GEGLU if gate exists or up has 2*n_ff rows) + // MLP (prefer GEGLU if gate exists or up has 2*n_ff rows) ggml_tensor * mlp_out = nullptr; const bool has_gate_tensor = (model.layers[il].ffn_gate != nullptr); const bool up_is_2x = (model.layers[il].ffn_up && model.layers[il].ffn_up->ne[0] == 2*hparams.n_ff()); @@ -7705,14 +7705,14 @@ struct llm_build_modern_bert : public llm_graph_context { cb(mlp_out, "ffn_out_gelu", il); } - // 7) Residual after MLP + // Residual after MLP ggml_tensor * cur_layer = ggml_add(ctx0, mlp_out, cur_attn); - // 8) feed into next layer + // feed into next layer inpL = cur_layer; } - // 9) final model norm (final_norm) + // final model norm (final_norm) cur = build_norm(inpL, model.output_norm, model.output_norm_b, LLM_NORM, -1); cb(cur, "final_norm", -1); diff --git a/src/llama-vocab.cpp b/src/llama-vocab.cpp index 92a21b6426f3f..0b6c8c73e2c50 100644 --- a/src/llama-vocab.cpp +++ b/src/llama-vocab.cpp @@ -1816,7 +1816,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { LLAMA_LOG_WARN("%s: ************************************ \n", __func__); LLAMA_LOG_WARN("%s: \n", __func__); pre_type = LLAMA_VOCAB_PRE_TYPE_DEFAULT; - } else if (tokenizer_pre == "default" || tokenizer_pre == "modern-bert") { + } else if (tokenizer_pre == "default" || tokenizer_pre == "modern-bert") /* need to fix modern-bert pre tokenizer */ { pre_type = LLAMA_VOCAB_PRE_TYPE_DEFAULT; } else if ( tokenizer_pre == "llama3" || From 41b68643337f28da416bb64a45d5dcb0478ebcef Mon Sep 17 00:00:00 2001 From: ryan-mangeno Date: Tue, 26 Aug 2025 12:33:11 -0400 Subject: [PATCH 05/26] cleanup --- src/llama-model.cpp | 37 ++++++++++++------------------------- 1 file changed, 12 insertions(+), 25 deletions(-) diff --git a/src/llama-model.cpp b/src/llama-model.cpp index e31f5e51593da..2ede73e4eccbf 100644 --- 
a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -2034,22 +2034,6 @@ bool llama_model::load_tensors(llama_model_loader & ml) { const int64_t n_expert_used = hparams.n_expert_used; const int64_t n_ctx_train = hparams.n_ctx_train; - LLAMA_LOG_INFO("n_head = %lld\n", (long long) n_head); - LLAMA_LOG_INFO("n_head_kv = %lld\n", (long long) n_head_kv); - LLAMA_LOG_INFO("n_embd = %lld\n", (long long) n_embd); - LLAMA_LOG_INFO("n_embd_k_gqa = %lld\n", (long long) n_embd_k_gqa); - LLAMA_LOG_INFO("n_embd_v_gqa = %lld\n", (long long) n_embd_v_gqa); - LLAMA_LOG_INFO("n_embd_head_k = %lld\n", (long long) n_embd_head_k); - LLAMA_LOG_INFO("n_embd_head_v = %lld\n", (long long) n_embd_head_v); - LLAMA_LOG_INFO("n_ff = %lld\n", (long long) n_ff); - LLAMA_LOG_INFO("n_embd_gqa = %lld\n", (long long) n_embd_gqa); - LLAMA_LOG_INFO("n_vocab = %lld\n", (long long) n_vocab); - LLAMA_LOG_INFO("n_token_types = %lld\n", (long long) n_token_types); - LLAMA_LOG_INFO("n_rot = %lld\n", (long long) n_rot); - LLAMA_LOG_INFO("n_expert = %lld\n", (long long) n_expert); - LLAMA_LOG_INFO("n_expert_used = %lld\n", (long long) n_expert_used); - LLAMA_LOG_INFO("n_ctx_train = %lld\n", (long long) n_ctx_train); - if (n_expert > 0 && hparams.n_expert_used == 0) { throw std::runtime_error("model has expert layers but no expert layers are used"); } @@ -7688,19 +7672,22 @@ struct llm_build_modern_bert : public llm_graph_context { if (has_gate_tensor || up_is_2x) { mlp_out = build_ffn( h, - model.layers[il].ffn_up, /*up_b*/ nullptr, /*up_shexp*/ nullptr, - model.layers[il].ffn_gate, /*gate_b*/ nullptr, /*gate_shexp*/ nullptr, - model.layers[il].ffn_down, /*down_b*/ nullptr, /*down_shexp*/ nullptr, - /*expert_scores*/ nullptr, + model.layers[il].ffn_up, /*up_b*/ NULL, /*up_shexp*/ NULL, + model.layers[il].ffn_gate, /*gate_b*/ NULL, /*gate_shexp*/ NULL, + model.layers[il].ffn_down, /*down_b*/ NULL, /*down_shexp*/ NULL, + /*expert_scores*/ NULL, LLM_FFN_GEGLU, LLM_FFN_PAR, il); cb(mlp_out, "ffn_out_geglu", il); } else { + + LLAMA_LOG_INFO("Ffn_up : {%lld, %lld}, ffn_down : {%lld, %lld}\n", model.layers[il].ffn_up->ne[0], model.layers[il].ffn_up->ne[1], + model.layers[il].ffn_down->ne[0], model.layers[il].ffn_down->ne[0]); mlp_out = build_ffn( h, - model.layers[il].ffn_up, /*up_b*/ nullptr, /*up_shexp*/ nullptr, - /*gate*/ nullptr, /*gate_b*/ nullptr, /*gate_shexp*/ nullptr, - model.layers[il].ffn_down, /*down_b*/ nullptr, /*down_shexp*/ nullptr, - /*expert_scores*/ nullptr, + model.layers[il].ffn_up, /*up_b*/ NULL, /*up_shexp*/ NULL, + /*gate*/ NULL, /*gate_b*/ NULL, /*gate_shexp*/ NULL, + model.layers[il].ffn_down, /*down_b*/ NULL, /*down_shexp*/ NULL, + /*expert_scores*/ NULL, LLM_FFN_GELU, LLM_FFN_SEQ, il); cb(mlp_out, "ffn_out_gelu", il); } @@ -7712,7 +7699,7 @@ struct llm_build_modern_bert : public llm_graph_context { inpL = cur_layer; } - // final model norm (final_norm) + // 9) final model norm (final_norm) cur = build_norm(inpL, model.output_norm, model.output_norm_b, LLM_NORM, -1); cb(cur, "final_norm", -1); From cc3d7abab4d0fe487bad39c6a91f940624c4089b Mon Sep 17 00:00:00 2001 From: ryan-mangeno Date: Tue, 26 Aug 2025 12:38:38 -0400 Subject: [PATCH 06/26] continuing --- ggml/src/ggml.c | 4 ++++ src/llama-graph.cpp | 4 +++- 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 55a76f8248c09..2505489b1e78a 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -3015,6 +3015,10 @@ struct ggml_tensor * ggml_mul_mat( struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * 
b) { + + printf("Up: {%lld, %lld}\n", a->ne[0], a->ne[1]); + printf("Cur: {%lld, %lld}\n", b->ne[0], b->ne[1]); + GGML_ASSERT(ggml_can_mul_mat(a, b)); GGML_ASSERT(!ggml_is_transposed(a)); diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index a05de6e585831..ae8b150d286af 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -644,6 +644,8 @@ ggml_tensor * llm_graph_context::build_ffn( llm_ffn_op_type type_op, llm_ffn_gate_type type_gate, int il) const { + + ggml_tensor * tmp = up ? build_lora_mm(up, cur) : cur; cb(tmp, "ffn_up", il); @@ -1377,7 +1379,7 @@ ggml_tensor * llm_graph_context::build_attn( // TODO: if ubatch.equal_seqs() == true, we can split the three tensors below into ubatch.n_seqs_unq streams LLAMA_LOG_INFO("ubatch.equal_seqs() = %d, n_seqs = %d\n", ubatch.equal_seqs(), ubatch.n_seqs); - assert(!ubatch.equal_seqs()); + // sassert(!ubatch.equal_seqs()); ggml_tensor * q = q_cur; ggml_tensor * k = k_cur; From 4ceb828112150e2efe025d09c8a5a49cf54a8b85 Mon Sep 17 00:00:00 2001 From: ryan-mangeno Date: Tue, 26 Aug 2025 13:01:30 -0400 Subject: [PATCH 07/26] correct tensor shape for qkv --- src/llama-model.cpp | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 2ede73e4eccbf..92ff8b876ff60 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -2704,11 +2704,12 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, TENSOR_NOT_REQUIRED); } - layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0); + + layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, 3 * n_embd }, 0); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_ff, n_embd} , 0); - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_embd, n_ff * 2}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_ff, n_embd}, 0); // [3072, 384] + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_embd, 2 * n_ff}, 0); layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); } } break; From 18c0c23ed89da13260c10e1086ca2fcffc4952d2 Mon Sep 17 00:00:00 2001 From: ryan-mangeno Date: Wed, 27 Aug 2025 15:32:20 -0400 Subject: [PATCH 08/26] fixed tensor mappings and working on buildin graph --- ggml/src/ggml.c | 3 -- gguf-py/gguf/tensor_mapping.py | 4 +- src/llama-model.cpp | 68 +++++++++++++++++++++++++--------- 3 files changed, 53 insertions(+), 22 deletions(-) diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 2505489b1e78a..79c0e437d3691 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -3015,9 +3015,6 @@ struct ggml_tensor * ggml_mul_mat( struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b) { - - printf("Up: {%lld, %lld}\n", a->ne[0], a->ne[1]); - printf("Cur: {%lld, %lld}\n", b->ne[0], b->ne[1]); GGML_ASSERT(ggml_can_mul_mat(a, b)); GGML_ASSERT(!ggml_is_transposed(a)); diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py index 2d3c16ab8496c..e775f0f575f18 100644 --- a/gguf-py/gguf/tensor_mapping.py +++ b/gguf-py/gguf/tensor_mapping.py @@ -367,7 +367,7 @@ class TensorNameMap: "model.layers.{bid}.mlp.up_proj", # llama-hf refact nemotron olmo2 "layers.{bid}.feed_forward.w3", # llama-pth "encoder.layer.{bid}.intermediate.dense", # bert - "layers.{bid}.mlp.Wo", # modern 
bert + "layers.{bid}.mlp.Wi", # modern bert "transformer.layer.{bid}.ffn.lin1", # distillbert "transformer.h.{bid}.mlp.fc_in", # gpt-j "transformer.h.{bid}.mlp.linear_3", # refact @@ -467,7 +467,7 @@ class TensorNameMap: "model.layers.{bid}.mlp.down_proj", # llama-hf nemotron olmo2 "layers.{bid}.feed_forward.w2", # llama-pth "encoder.layer.{bid}.output.dense", # bert - "layers.{bid}.mlp.Wi", # modern bert + "layers.{bid}.mlp.Wo", # modern bert "transformer.layer.{bid}.ffn.lin2", # distillbert "transformer.h.{bid}.mlp.fc_out", # gpt-j "language_model.encoder.layers.{bid}.mlp.dense_4h_to_h", # persimmon diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 92ff8b876ff60..6a8953af33337 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -2708,8 +2708,8 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, 3 * n_embd }, 0); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_ff, n_embd}, 0); // [3072, 384] - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_embd, 2 * n_ff}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, 2 * n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); } } break; @@ -7548,6 +7548,7 @@ struct llm_build_modern_bert : public llm_graph_context { const int64_t n_embd_head = hparams.n_embd_head_v; const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); // == n_head_kv * n_embd_head const int64_t n_tokens = ubatch.n_tokens; + const int64_t n_ff = hparams.n_ff(); GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); @@ -7667,30 +7668,63 @@ struct llm_build_modern_bert : public llm_graph_context { // MLP (prefer GEGLU if gate exists or up has 2*n_ff rows) ggml_tensor * mlp_out = nullptr; - const bool has_gate_tensor = (model.layers[il].ffn_gate != nullptr); - const bool up_is_2x = (model.layers[il].ffn_up && model.layers[il].ffn_up->ne[0] == 2*hparams.n_ff()); + ggml_tensor * ffn_gate_view = model.layers[il].ffn_gate; + ggml_tensor * ffn_up_view = model.layers[il].ffn_up; + + if (ffn_gate_view == nullptr && ffn_up_view) { + + // Case A: weight stored as (2*ffn, hidden) -> split rows into two (ffn x hidden) + if( ffn_up_view->ne[0] == 2 * n_ff and ffn_up_view->ne[1] == n_embd) { + + // top half, (ffn up) + ffn_up_view = ggml_view_2d(ctx0, model.layers[il].ffn_up, + /*ne0*/ n_ff, /*ne1*/ n_embd, + /*nb1*/ model.layers[il].ffn_up->nb[1], + /*offset_bytes*/ (size_t)0); + // bottom half (gate) + ffn_gate_view = ggml_view_2d(ctx0, model.layers[il].ffn_up, + /*ne0*/ n_ff, /*ne1*/ n_embd, + /*nb1*/ model.layers[il].ffn_up->nb[1], + /*offset_bytes*/ (size_t)n_ff * model.layers[il].ffn_up->nb[1]); + } + else if ( ffn_up_view->ne[0] == n_embd && ffn_up_view->ne[1] == 2 * n_ff) { + // top half + ffn_up_view = ggml_view_2d(ctx0, model.layers[il].ffn_up, + n_embd, n_ff, + model.layers[il].ffn_up->nb[1], + 0); + ffn_up_view = ggml_cont(ctx0, ffn_up_view); + + ffn_gate_view = ggml_view_2d(ctx0, model.layers[il].ffn_up, + n_embd, n_ff, + model.layers[il].ffn_up->nb[1], + n_ff * sizeof(float)); + ffn_gate_view = ggml_cont(ctx0, ffn_gate_view); + } + + ggml_tensor * ffn_down_view = model.layers[il].ffn_down; + LLAMA_LOG_INFO("ffn shapes: Up: {%lld, %lld}, Gate: {%lld, %lld}, Down: {%lld, %lld}", + ffn_up_view->ne[0], 
ffn_up_view->ne[1], ffn_gate_view->ne[0], ffn_gate_view->ne[1], ffn_down_view->ne[0], ffn_down_view->ne[1]); - if (has_gate_tensor || up_is_2x) { mlp_out = build_ffn( h, model.layers[il].ffn_up, /*up_b*/ NULL, /*up_shexp*/ NULL, - model.layers[il].ffn_gate, /*gate_b*/ NULL, /*gate_shexp*/ NULL, + ffn_gate_view , /*gate_b*/ NULL, /*gate_shexp*/ NULL, model.layers[il].ffn_down, /*down_b*/ NULL, /*down_shexp*/ NULL, /*expert_scores*/ NULL, - LLM_FFN_GEGLU, LLM_FFN_PAR, il); - cb(mlp_out, "ffn_out_geglu", il); + LLM_FFN_GEGLU, LLM_FFN_PAR, il + ); + cb(mlp_out, "ffn_out_geglu", il); } else { - - LLAMA_LOG_INFO("Ffn_up : {%lld, %lld}, ffn_down : {%lld, %lld}\n", model.layers[il].ffn_up->ne[0], model.layers[il].ffn_up->ne[1], - model.layers[il].ffn_down->ne[0], model.layers[il].ffn_down->ne[0]); mlp_out = build_ffn( h, - model.layers[il].ffn_up, /*up_b*/ NULL, /*up_shexp*/ NULL, - /*gate*/ NULL, /*gate_b*/ NULL, /*gate_shexp*/ NULL, - model.layers[il].ffn_down, /*down_b*/ NULL, /*down_shexp*/ NULL, - /*expert_scores*/ NULL, - LLM_FFN_GELU, LLM_FFN_SEQ, il); - cb(mlp_out, "ffn_out_gelu", il); + model.layers[il].ffn_up, NULL, NULL, + model.layers[il].ffn_gate, NULL, NULL, + model.layers[il].ffn_down, NULL, NULL, + NULL, + LLM_FFN_GEGLU, LLM_FFN_PAR, il + ); + cb(mlp_out, "ffn_out_geglu", il); } // Residual after MLP From bffe3c9092311498e951821068eaa9c9fae10743 Mon Sep 17 00:00:00 2001 From: ryan-mangeno Date: Thu, 28 Aug 2025 11:15:10 -0400 Subject: [PATCH 09/26] tensor debugging now works -> (llama-eval-callback), instead of simulated gate split with views, GEGLU is now used which does exactly this --- src/llama-graph.cpp | 16 +++++++++++++++- src/llama-model.cpp | 36 ++++++++++++++++++++++++++---------- 2 files changed, 41 insertions(+), 11 deletions(-) diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index ae8b150d286af..972d37306c854 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -645,8 +645,11 @@ ggml_tensor * llm_graph_context::build_ffn( llm_ffn_gate_type type_gate, int il) const { - + LLAMA_LOG_INFO("building lora: up is {%lld, %lld}\n input is {%lld, %lld}\n", up->ne[0], up->ne[1], cur->ne[0], cur->ne[1]); + ggml_tensor * tmp = up ? 
build_lora_mm(up, cur) : cur; + LLAMA_LOG_INFO("Building FFN\n"); + LLAMA_LOG_INFO("built lora: tmp is {%lld, %lld}\n", tmp->ne[0], tmp->ne[1]); cb(tmp, "ffn_up", il); if (up_b) { @@ -669,6 +672,8 @@ ggml_tensor * llm_graph_context::build_ffn( case LLM_FFN_PAR: { cur = build_lora_mm(gate, cur); + LLAMA_LOG_INFO("built lora: cur is {%lld, %lld}\n", cur->ne[0], cur->ne[1]); + cb(cur, "ffn_gate", il); } break; } @@ -687,6 +692,10 @@ ggml_tensor * llm_graph_context::build_ffn( cur = tmp; } + if( gate && type_gate == LLM_FFN_PAR ) { + LLAMA_LOG_INFO("Gate Exists and In Paralell\n"); + } + switch (type_op) { case LLM_FFN_SILU: if (gate && type_gate == LLM_FFN_PAR) { @@ -735,6 +744,7 @@ ggml_tensor * llm_graph_context::build_ffn( case LLM_FFN_GEGLU: { cur = ggml_geglu(ctx0, cur); + LLAMA_LOG_INFO("geglu split: cur is {%lld, %lld}\n", cur->ne[0], cur->ne[1]); cb(cur, "ffn_geglu", il); } break; case LLM_FFN_REGLU: @@ -747,12 +757,16 @@ ggml_tensor * llm_graph_context::build_ffn( } if (gate && type_gate == LLM_FFN_PAR) { + LLAMA_LOG_INFO("cur @ tmp: cur is {%lld, %lld}\n tmp is {%lld, %lld}\n", cur->ne[0], cur->ne[1], tmp->ne[0], tmp->ne[1]); cur = ggml_mul(ctx0, cur, tmp); + LLAMA_LOG_INFO("res is {%lld, %lld}\n", cur->ne[0], cur->ne[1]); cb(cur, "ffn_gate_par", il); } if (down) { cur = build_lora_mm(down, cur); + LLAMA_LOG_INFO("built lora: cur is {%lld, %lld}\n", cur->ne[0], cur->ne[1]); + if (arch == LLM_ARCH_GLM4 || arch == LLM_ARCH_GLM4_MOE) { // GLM4 and GLM4_MOE seem to have numerical issues with half-precision accumulators ggml_mul_mat_set_prec(cur, GGML_PREC_F32); diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 6a8953af33337..c02a3078d7a48 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -7672,7 +7672,7 @@ struct llm_build_modern_bert : public llm_graph_context { ggml_tensor * ffn_up_view = model.layers[il].ffn_up; if (ffn_gate_view == nullptr && ffn_up_view) { - + // Case A: weight stored as (2*ffn, hidden) -> split rows into two (ffn x hidden) if( ffn_up_view->ne[0] == 2 * n_ff and ffn_up_view->ne[1] == n_embd) { @@ -7685,33 +7685,49 @@ struct llm_build_modern_bert : public llm_graph_context { ffn_gate_view = ggml_view_2d(ctx0, model.layers[il].ffn_up, /*ne0*/ n_ff, /*ne1*/ n_embd, /*nb1*/ model.layers[il].ffn_up->nb[1], + /*offset_bytes*/ (size_t)n_ff * model.layers[il].ffn_up->nb[1]); } + + /* else if ( ffn_up_view->ne[0] == n_embd && ffn_up_view->ne[1] == 2 * n_ff) { // top half + LLAMA_LOG_INFO("Case B:\n"); ffn_up_view = ggml_view_2d(ctx0, model.layers[il].ffn_up, n_embd, n_ff, model.layers[il].ffn_up->nb[1], 0); + ffn_up_view = ggml_cont(ctx0, ffn_up_view); ffn_gate_view = ggml_view_2d(ctx0, model.layers[il].ffn_up, n_embd, n_ff, model.layers[il].ffn_up->nb[1], - n_ff * sizeof(float)); + n_ff * model.layers[il].ffn_up->nb[0]); ffn_gate_view = ggml_cont(ctx0, ffn_gate_view); } - - ggml_tensor * ffn_down_view = model.layers[il].ffn_down; - LLAMA_LOG_INFO("ffn shapes: Up: {%lld, %lld}, Gate: {%lld, %lld}, Down: {%lld, %lld}", - ffn_up_view->ne[0], ffn_up_view->ne[1], ffn_gate_view->ne[0], ffn_gate_view->ne[1], ffn_down_view->ne[0], ffn_down_view->ne[1]); - + */ + //ggml_tensor * ffn_down_view = model.layers[il].ffn_down; + //LLAMA_LOG_INFO("ffn shapes: Up: {%lld, %lld}, Gate: {%lld, %lld}, Down: {%lld, %lld}\n", + // ffn_up_view->ne[0], ffn_up_view->ne[1], ffn_gate_view->ne[0], ffn_gate_view->ne[1], ffn_down_view->ne[0], ffn_down_view->ne[1]); + /* + ggml_tensor * cur, + ggml_tensor * up, + ggml_tensor * up_b, + ggml_tensor * up_s, + ggml_tensor * gate, + 
ggml_tensor * gate_b, + ggml_tensor * gate_s, + ggml_tensor * down, + ggml_tensor * down_b, + ggml_tensor * down_s, + ggml_tensor * act_scales,*/ mlp_out = build_ffn( h, - model.layers[il].ffn_up, /*up_b*/ NULL, /*up_shexp*/ NULL, - ffn_gate_view , /*gate_b*/ NULL, /*gate_shexp*/ NULL, + model.layers[il].ffn_up, /*up_b*/ NULL, /*up_shexp*/ NULL, + NULL , /*gate_b*/ NULL, /*gate_shexp*/ NULL, model.layers[il].ffn_down, /*down_b*/ NULL, /*down_shexp*/ NULL, - /*expert_scores*/ NULL, + /*act_scales*/ NULL, LLM_FFN_GEGLU, LLM_FFN_PAR, il ); cb(mlp_out, "ffn_out_geglu", il); From 8f328431a1032ef63bbdf386f18c8f18f4bcc088 Mon Sep 17 00:00:00 2001 From: ryan-mangeno Date: Thu, 28 Aug 2025 12:33:52 -0400 Subject: [PATCH 10/26] cleanup --- src/llama-graph.cpp | 14 ------ src/llama-model.cpp | 114 ++++++++++---------------------------------- 2 files changed, 25 insertions(+), 103 deletions(-) diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index 972d37306c854..1512869ec6817 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -645,11 +645,8 @@ ggml_tensor * llm_graph_context::build_ffn( llm_ffn_gate_type type_gate, int il) const { - LLAMA_LOG_INFO("building lora: up is {%lld, %lld}\n input is {%lld, %lld}\n", up->ne[0], up->ne[1], cur->ne[0], cur->ne[1]); ggml_tensor * tmp = up ? build_lora_mm(up, cur) : cur; - LLAMA_LOG_INFO("Building FFN\n"); - LLAMA_LOG_INFO("built lora: tmp is {%lld, %lld}\n", tmp->ne[0], tmp->ne[1]); cb(tmp, "ffn_up", il); if (up_b) { @@ -672,8 +669,6 @@ ggml_tensor * llm_graph_context::build_ffn( case LLM_FFN_PAR: { cur = build_lora_mm(gate, cur); - LLAMA_LOG_INFO("built lora: cur is {%lld, %lld}\n", cur->ne[0], cur->ne[1]); - cb(cur, "ffn_gate", il); } break; } @@ -692,10 +687,6 @@ ggml_tensor * llm_graph_context::build_ffn( cur = tmp; } - if( gate && type_gate == LLM_FFN_PAR ) { - LLAMA_LOG_INFO("Gate Exists and In Paralell\n"); - } - switch (type_op) { case LLM_FFN_SILU: if (gate && type_gate == LLM_FFN_PAR) { @@ -744,7 +735,6 @@ ggml_tensor * llm_graph_context::build_ffn( case LLM_FFN_GEGLU: { cur = ggml_geglu(ctx0, cur); - LLAMA_LOG_INFO("geglu split: cur is {%lld, %lld}\n", cur->ne[0], cur->ne[1]); cb(cur, "ffn_geglu", il); } break; case LLM_FFN_REGLU: @@ -757,16 +747,12 @@ ggml_tensor * llm_graph_context::build_ffn( } if (gate && type_gate == LLM_FFN_PAR) { - LLAMA_LOG_INFO("cur @ tmp: cur is {%lld, %lld}\n tmp is {%lld, %lld}\n", cur->ne[0], cur->ne[1], tmp->ne[0], tmp->ne[1]); cur = ggml_mul(ctx0, cur, tmp); - LLAMA_LOG_INFO("res is {%lld, %lld}\n", cur->ne[0], cur->ne[1]); cb(cur, "ffn_gate_par", il); } if (down) { cur = build_lora_mm(down, cur); - LLAMA_LOG_INFO("built lora: cur is {%lld, %lld}\n", cur->ne[0], cur->ne[1]); - if (arch == LLM_ARCH_GLM4 || arch == LLM_ARCH_GLM4_MOE) { // GLM4 and GLM4_MOE seem to have numerical issues with half-precision accumulators ggml_mul_mat_set_prec(cur, GGML_PREC_F32); diff --git a/src/llama-model.cpp b/src/llama-model.cpp index c02a3078d7a48..897c58ac14b0c 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -2696,11 +2696,12 @@ bool llama_model::load_tensors(llama_model_loader & ml) { for(int i = 0; i < n_layer; ++i) { auto& layer = layers[i]; - // layer 0 uses identity so we dont need weights for said layer if ( i != 0 ) { layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + } else{ + // layer 0 uses identity so we dont need weights for said layer layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, TENSOR_NOT_REQUIRED); } @@ -7546,14 
+7547,14 @@ struct llm_build_modern_bert : public llm_graph_context { const int64_t n_head = hparams.n_head(); const int64_t n_head_kv = hparams.n_head_kv(); const int64_t n_embd_head = hparams.n_embd_head_v; - const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); // == n_head_kv * n_embd_head + const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); const int64_t n_tokens = ubatch.n_tokens; const int64_t n_ff = hparams.n_ff(); GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); // RoPE params - const int32_t rope_type = LLAMA_ROPE_TYPE_NEOX; // ModernBERT uses rotary + const int32_t rope_type = LLAMA_ROPE_TYPE_NEOX; // uses rotary const int32_t n_rot = hparams.n_rot; const int32_t n_ctx_orig = hparams.n_ctx_train; @@ -7561,7 +7562,7 @@ struct llm_build_modern_bert : public llm_graph_context { ggml_tensor * inpL; ggml_tensor * inp_pos = nullptr; - // ModernBERT needs positions for RoPE + // needs positions for RoPE inp_pos = build_inp_pos(); // embeddings (token + optional type), NO absolute pos embed @@ -7583,7 +7584,7 @@ struct llm_build_modern_bert : public llm_graph_context { for (int il = 0; il < n_layer; ++il) { ggml_tensor * x = inpL; - // pre-attention norm (attn_norm). Layer 0 may be Identity() -> nullptr + // pre attention norm (attn_norm). Layer 0 may be Identity() -> nullptr ggml_tensor * x_attn_in = x; if (model.layers[il].attn_norm) { x_attn_in = build_norm(x, @@ -7592,6 +7593,7 @@ struct llm_build_modern_bert : public llm_graph_context { LLM_NORM, il); cb(x_attn_in, "attn_pre_norm", il); } else { + LLAMA_LOG_INFO("Identity Tensor\n"); cb(x_attn_in, "attn_pre_norm_identity", il); } @@ -7601,7 +7603,7 @@ struct llm_build_modern_bert : public llm_graph_context { ggml_tensor * Kcur; ggml_tensor * Vcur; - GGML_ASSERT(model.layers[il].wqkv); // ModernBERT uses fused QKV + GGML_ASSERT(model.layers[il].wqkv); // fused QKV qkv = build_lora_mm(model.layers[il].wqkv, x_attn_in); cb(qkv, "wqkv", il); @@ -7615,7 +7617,7 @@ struct llm_build_modern_bert : public llm_graph_context { Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, qkv, n_embd_gqa, n_tokens, qkv->nb[1], 1*sizeof(float)*(n_embd))); Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, qkv, n_embd_gqa, n_tokens, qkv->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa))); - // Optional per Q/K + // optional per Q/K if (model.layers[il].attn_q_norm) { Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, model.layers[il].attn_q_norm_b, LLM_NORM, il); } @@ -7623,12 +7625,12 @@ struct llm_build_modern_bert : public llm_graph_context { Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, model.layers[il].attn_k_norm_b, LLM_NORM, il); } - // Heads + // heads Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); - // RoPE (NEOX) on Q and K + // RoPE (NEOX ... maybe?) 
on Q and K Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow); @@ -7650,99 +7652,33 @@ struct llm_build_modern_bert : public llm_graph_context { il); cb(attn_out, "attn_out", il); - // Residual after attention + // residual after attention ggml_tensor * cur_attn = ggml_add(ctx0, attn_out, x); - // If we subselect outputs, do it at the last layer after attn resid + // ifwe subselect outputs, do it at the last layer after attn resid if (il == n_layer - 1 && inp_out_ids) { cur_attn = ggml_get_rows(ctx0, cur_attn, inp_out_ids); x = ggml_get_rows(ctx0, x, inp_out_ids); } - // pre-MLP norm (mlp_norm) + // pre mlp norm ggml_tensor * h = build_norm(cur_attn, model.layers[il].ffn_norm, model.layers[il].ffn_norm_b, LLM_NORM, il); cb(h, "mlp_pre_norm", il); - // MLP (prefer GEGLU if gate exists or up has 2*n_ff rows) - ggml_tensor * mlp_out = nullptr; - ggml_tensor * ffn_gate_view = model.layers[il].ffn_gate; - ggml_tensor * ffn_up_view = model.layers[il].ffn_up; - - if (ffn_gate_view == nullptr && ffn_up_view) { - - // Case A: weight stored as (2*ffn, hidden) -> split rows into two (ffn x hidden) - if( ffn_up_view->ne[0] == 2 * n_ff and ffn_up_view->ne[1] == n_embd) { - - // top half, (ffn up) - ffn_up_view = ggml_view_2d(ctx0, model.layers[il].ffn_up, - /*ne0*/ n_ff, /*ne1*/ n_embd, - /*nb1*/ model.layers[il].ffn_up->nb[1], - /*offset_bytes*/ (size_t)0); - // bottom half (gate) - ffn_gate_view = ggml_view_2d(ctx0, model.layers[il].ffn_up, - /*ne0*/ n_ff, /*ne1*/ n_embd, - /*nb1*/ model.layers[il].ffn_up->nb[1], - - /*offset_bytes*/ (size_t)n_ff * model.layers[il].ffn_up->nb[1]); - } - - /* - else if ( ffn_up_view->ne[0] == n_embd && ffn_up_view->ne[1] == 2 * n_ff) { - // top half - LLAMA_LOG_INFO("Case B:\n"); - ffn_up_view = ggml_view_2d(ctx0, model.layers[il].ffn_up, - n_embd, n_ff, - model.layers[il].ffn_up->nb[1], - 0); - - ffn_up_view = ggml_cont(ctx0, ffn_up_view); - - ffn_gate_view = ggml_view_2d(ctx0, model.layers[il].ffn_up, - n_embd, n_ff, - model.layers[il].ffn_up->nb[1], - n_ff * model.layers[il].ffn_up->nb[0]); - ffn_gate_view = ggml_cont(ctx0, ffn_gate_view); - } - */ - //ggml_tensor * ffn_down_view = model.layers[il].ffn_down; - //LLAMA_LOG_INFO("ffn shapes: Up: {%lld, %lld}, Gate: {%lld, %lld}, Down: {%lld, %lld}\n", - // ffn_up_view->ne[0], ffn_up_view->ne[1], ffn_gate_view->ne[0], ffn_gate_view->ne[1], ffn_down_view->ne[0], ffn_down_view->ne[1]); - /* - ggml_tensor * cur, - ggml_tensor * up, - ggml_tensor * up_b, - ggml_tensor * up_s, - ggml_tensor * gate, - ggml_tensor * gate_b, - ggml_tensor * gate_s, - ggml_tensor * down, - ggml_tensor * down_b, - ggml_tensor * down_s, - ggml_tensor * act_scales,*/ - mlp_out = build_ffn( - h, - model.layers[il].ffn_up, /*up_b*/ NULL, /*up_shexp*/ NULL, - NULL , /*gate_b*/ NULL, /*gate_shexp*/ NULL, - model.layers[il].ffn_down, /*down_b*/ NULL, /*down_shexp*/ NULL, - /*act_scales*/ NULL, - LLM_FFN_GEGLU, LLM_FFN_PAR, il - ); - cb(mlp_out, "ffn_out_geglu", il); - } else { - mlp_out = build_ffn( - h, - model.layers[il].ffn_up, NULL, NULL, - model.layers[il].ffn_gate, NULL, NULL, - model.layers[il].ffn_down, NULL, NULL, - NULL, - LLM_FFN_GEGLU, LLM_FFN_PAR, il - ); - cb(mlp_out, "ffn_out_geglu", il); - } + // GEGLU because we will split ffn_up which has shape [n_embd, n_ff * 2] and ffn_down has shape [n_ff, n_embd] + ggml_tensor * mlp_out = build_ffn( + h, + model.layers[il].ffn_up, /*up_b*/ NULL, /*up_shexp*/ NULL, + /*gate*/ NULL , /*gate_b*/ NULL, 
/*gate_shexp*/ NULL, + model.layers[il].ffn_down, /*down_b*/ NULL, /*down_shexp*/ NULL, + /*act_scales*/ NULL, + LLM_FFN_GEGLU, LLM_FFN_PAR, il + ); + cb(mlp_out, "ffn_out_geglu", il); // Residual after MLP ggml_tensor * cur_layer = ggml_add(ctx0, mlp_out, cur_attn); @@ -7750,7 +7686,7 @@ struct llm_build_modern_bert : public llm_graph_context { inpL = cur_layer; } - // 9) final model norm (final_norm) + // final model norm (final_norm) cur = build_norm(inpL, model.output_norm, model.output_norm_b, LLM_NORM, -1); cb(cur, "final_norm", -1); From 9805635c122e3b632a95b6bdd111f71578e103f9 Mon Sep 17 00:00:00 2001 From: ryan-mangeno Date: Thu, 28 Aug 2025 12:36:26 -0400 Subject: [PATCH 11/26] cleanup --- convert_hf_to_gguf.py | 1 - 1 file changed, 1 deletion(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 6ed2587f8a9b5..7031e161e9ef7 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -133,7 +133,6 @@ def get_remote_tensors() -> Iterator[tuple[str, Tensor]]: self.ftype = gguf.LlamaFileType.MOSTLY_BF16 # Configure GGUF Writer - print(f"arch: {gguf.MODEL_ARCH_NAMES[self.model_arch]}") self.gguf_writer = gguf.GGUFWriter(path=None, arch=gguf.MODEL_ARCH_NAMES[self.model_arch], endianess=self.endianess, use_temp_file=self.use_temp_file, split_max_tensors=split_max_tensors, split_max_size=split_max_size, dry_run=dry_run, small_first_shard=small_first_shard) From 40249dd5ec357892dcd85ac583186770dc4187c4 Mon Sep 17 00:00:00 2001 From: ryan-mangeno Date: Thu, 28 Aug 2025 12:37:02 -0400 Subject: [PATCH 12/26] cleanup --- convert_hf_to_gguf.py | 1 - 1 file changed, 1 deletion(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 7031e161e9ef7..3c483e10287f8 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -465,7 +465,6 @@ def print_registered_models(cls): @classmethod def from_model_architecture(cls, arch: str, model_type = ModelType.TEXT) -> type[ModelBase]: try: - print(f"model_type: {model_type}, arch: {arch}") return cls._model_classes[model_type][arch] except KeyError: raise NotImplementedError(f'Architecture {arch!r} not supported!') from None From 853f344cfe5220b7be6bc6fdb1e3d436c049f07e Mon Sep 17 00:00:00 2001 From: ryan-mangeno Date: Thu, 28 Aug 2025 12:47:10 -0400 Subject: [PATCH 13/26] more cleanup --- ggml/src/ggml.c | 1 - src/llama-model.cpp | 24 ++++++++++-------------- src/llama-vocab.cpp | 3 --- src/llama.cpp | 1 - 4 files changed, 10 insertions(+), 19 deletions(-) diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 79c0e437d3691..55a76f8248c09 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -3015,7 +3015,6 @@ struct ggml_tensor * ggml_mul_mat( struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b) { - GGML_ASSERT(ggml_can_mul_mat(a, b)); GGML_ASSERT(!ggml_is_transposed(a)); diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 897c58ac14b0c..860e55859552a 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -451,7 +451,6 @@ void llama_model::load_arch(llama_model_loader & ml) { } void llama_model::load_hparams(llama_model_loader & ml) { - const gguf_context * ctx = ml.meta.get(); // get metadata as string @@ -465,7 +464,6 @@ void llama_model::load_hparams(llama_model_loader & ml) { gguf_kv.emplace(name, value); } - // get general kv ml.get_key(LLM_KV_GENERAL_NAME, name, false); @@ -586,7 +584,6 @@ void llama_model::load_hparams(llama_model_loader & ml) { } // arch-specific KVs - LLAMA_LOG_INFO("Switching Arch\n"); switch (arch) { case LLM_ARCH_LLAMA: { @@ -1901,6 +1898,7 @@ void 
llama_model::load_hparams(llama_model_loader & ml) { void llama_model::load_vocab(llama_model_loader & ml) { const auto kv = LLM_KV(arch); + vocab.load(ml, kv); } @@ -2045,7 +2043,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { auto create_tensor = [&](const LLM_TN_IMPL & tn, const std::initializer_list & ne, int flags) -> ggml_tensor * { ggml_tensor * t_meta = ml.get_tensor_meta(tn.str().c_str()); - LLAMA_LOG_INFO("Creating Tensor: %s\n", tn.str().c_str()); + if (!t_meta) { if (flags & TENSOR_NOT_REQUIRED) { return nullptr; @@ -2120,6 +2118,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { } ggml_backend_buffer_type_t buft = nullptr; + // check overrides if (ml.tensor_buft_overrides) { std::string tensor_name = tn.str(); @@ -2167,6 +2166,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { first_moved_to_buft = buft; } } + ggml_context * ctx = ctx_for_buft(buft); // if duplicated, check if the original tensor was allocated in the same buffer type context and avoid creating a new one @@ -2624,14 +2624,11 @@ bool llama_model::load_tensors(llama_model_loader & ml) { case LLM_ARCH_NOMIC_BERT_MOE: { tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - tok_norm = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {n_embd}, 0); - type_embd = create_tensor(tn(LLM_TENSOR_TOKEN_TYPES, "weight"), {n_embd, n_token_types}, TENSOR_NOT_REQUIRED); - tok_norm_b = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias"), {n_embd}, 0); - if (arch == LLM_ARCH_BERT) { pos_embd = create_tensor(tn(LLM_TENSOR_POS_EMBD, "weight"), {n_embd, n_ctx_train}, 0); + cls = create_tensor(tn(LLM_TENSOR_CLS, "weight"), {n_embd, n_embd}, TENSOR_NOT_REQUIRED); cls_b = create_tensor(tn(LLM_TENSOR_CLS, "bias"), {n_embd}, TENSOR_NOT_REQUIRED); @@ -2639,11 +2636,14 @@ bool llama_model::load_tensors(llama_model_loader & ml) { cls_out_b = create_tensor(tn(LLM_TENSOR_CLS_OUT, "bias"), {hparams.n_cls_out}, TENSOR_NOT_REQUIRED); } + tok_norm = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {n_embd}, 0); + tok_norm_b = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias"), {n_embd}, 0); + for (int i = 0; i < n_layer; ++i) { auto & layer = layers[i]; layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, TENSOR_NOT_REQUIRED); - + layer.bqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, TENSOR_NOT_REQUIRED); if (!layer.wqkv) { layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0); @@ -2657,8 +2657,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { } layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); - - layer.bqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, TENSOR_NOT_REQUIRED); + layer.attn_out_norm = create_tensor(tn(LLM_TENSOR_ATTN_OUT_NORM, "weight", i), {n_embd}, 0); layer.attn_out_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT_NORM, "bias", i), {n_embd}, 0); @@ -2668,7 +2667,6 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff, n_embd, n_expert}, 0); layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); } else { - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); @@ -2683,7 +2681,6 @@ bool 
llama_model::load_tensors(llama_model_loader & ml) { layer.layer_out_norm = create_tensor(tn(LLM_TENSOR_LAYER_OUT_NORM, "weight", i), {n_embd}, 0); layer.layer_out_norm_b = create_tensor(tn(LLM_TENSOR_LAYER_OUT_NORM, "bias", i), {n_embd}, 0); - } } break; case LLM_ARCH_MODERN_BERT: @@ -7549,7 +7546,6 @@ struct llm_build_modern_bert : public llm_graph_context { const int64_t n_embd_head = hparams.n_embd_head_v; const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); const int64_t n_tokens = ubatch.n_tokens; - const int64_t n_ff = hparams.n_ff(); GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); diff --git a/src/llama-vocab.cpp b/src/llama-vocab.cpp index 0b6c8c73e2c50..21420389ec9c1 100644 --- a/src/llama-vocab.cpp +++ b/src/llama-vocab.cpp @@ -1661,13 +1661,10 @@ struct llama_vocab::impl { void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { struct gguf_context * ctx = ml.meta.get(); - LLAMA_LOG_INFO("Determining Vocab Type\n"); // determine vocab type { ml.get_key(LLM_KV_TOKENIZER_MODEL, tokenizer_model); ml.get_key(LLM_KV_TOKENIZER_PRE, tokenizer_pre, false); - LLAMA_LOG_INFO("pre tokenizer model: %s\n", tokenizer_pre.c_str()); - LLAMA_LOG_INFO("tokenizer model: %s\n", tokenizer_model.c_str()); ml.get_key(LLM_KV_TOKENIZER_TOKEN_TYPE_COUNT, n_token_types, false); diff --git a/src/llama.cpp b/src/llama.cpp index 024e142453768..34906cdb62844 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -126,7 +126,6 @@ static int llama_model_load(const std::string & fname, std::vector if (!model.load_tensors(ml)) { return -2; } - } catch (const std::exception & err) { LLAMA_LOG_ERROR("%s: error loading model: %s\n", __func__, err.what()); return -1; From 2a1c75047c1844bd53eba97b5b1bef6813cf3cae Mon Sep 17 00:00:00 2001 From: ryan-mangeno Date: Thu, 28 Aug 2025 12:59:42 -0400 Subject: [PATCH 14/26] ubatch issues, the assert for checking equal seqs in llama-graph.cpp when building attention keeps failing, setting ubatch size to 1 when running llama-embedding with --ubatch-size 1 makes it work, but needs to be looked into more --- src/llama-graph.cpp | 8 +++----- src/llama-model.cpp | 1 - 2 files changed, 3 insertions(+), 6 deletions(-) diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index 1512869ec6817..9ca2e579d7299 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -644,8 +644,6 @@ ggml_tensor * llm_graph_context::build_ffn( llm_ffn_op_type type_op, llm_ffn_gate_type type_gate, int il) const { - - ggml_tensor * tmp = up ? 
build_lora_mm(up, cur) : cur; cb(tmp, "ffn_up", il); @@ -1377,9 +1375,9 @@ ggml_tensor * llm_graph_context::build_attn( // [TAG_NO_CACHE_PAD] // TODO: if ubatch.equal_seqs() == true, we can split the three tensors below into ubatch.n_seqs_unq streams - LLAMA_LOG_INFO("ubatch.equal_seqs() = %d, n_seqs = %d\n", ubatch.equal_seqs(), ubatch.n_seqs); - - // sassert(!ubatch.equal_seqs()); + if (ubatch.n_seqs > 1) { + assert(!ubatch.equal_seqs()); + } ggml_tensor * q = q_cur; ggml_tensor * k = k_cur; diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 860e55859552a..88784ddadd48a 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -7589,7 +7589,6 @@ struct llm_build_modern_bert : public llm_graph_context { LLM_NORM, il); cb(x_attn_in, "attn_pre_norm", il); } else { - LLAMA_LOG_INFO("Identity Tensor\n"); cb(x_attn_in, "attn_pre_norm_identity", il); } From c73eb685fd4e6eee0a4b43e8b6e50a902ab19e31 Mon Sep 17 00:00:00 2001 From: ryan-mangeno Date: Fri, 29 Aug 2025 12:15:31 -0400 Subject: [PATCH 15/26] added cls token per previous modern bert attempt, still working on checking out the rest --- gguf-py/gguf/constants.py | 2 ++ src/llama-arch.cpp | 2 ++ src/llama-model.cpp | 5 +++++ 3 files changed, 9 insertions(+) diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index 1273ca31d5830..607486a31a37d 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -1185,6 +1185,8 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.FFN_UP, MODEL_TENSOR.FFN_DOWN, MODEL_TENSOR.FFN_NORM, + MODEL_TENSOR.CLS, + MODEL_TENSOR.CLS_OUT, ], MODEL_ARCH.NOMIC_BERT: [ MODEL_TENSOR.TOKEN_EMBD, diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp index 031b4c486f609..9a009ac902ccc 100644 --- a/src/llama-arch.cpp +++ b/src/llama-arch.cpp @@ -519,6 +519,8 @@ static const std::map> LLM_TENSOR_N { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_CLS, "cls" }, + { LLM_TENSOR_CLS_OUT, "cls.output" }, }, }, { diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 88784ddadd48a..a159eb347201d 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -2710,6 +2710,11 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); } + + cls = create_tensor(tn(LLM_TENSOR_CLS, "weight"), {n_embd, n_embd}, TENSOR_NOT_REQUIRED); + cls_out = create_tensor(tn(LLM_TENSOR_CLS_OUT, "weight"), {n_embd, hparams.n_cls_out}, TENSOR_NOT_REQUIRED); + cls_out_b = create_tensor(tn(LLM_TENSOR_CLS_OUT, "bias"), {hparams.n_cls_out}, TENSOR_NOT_REQUIRED); + } break; case LLM_ARCH_NEO_BERT: { From ca353d37b459c3da1db8df3ccf9b2cdceda9f281 Mon Sep 17 00:00:00 2001 From: ryan-mangeno Date: Tue, 2 Sep 2025 12:26:20 -0400 Subject: [PATCH 16/26] fixed pre tokenizer and still working through previous pr --- gguf-py/gguf/constants.py | 1 + gguf-py/gguf/gguf_writer.py | 3 +++ src/llama-vocab.cpp | 3 ++- 3 files changed, 6 insertions(+), 1 deletion(-) diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index 607486a31a37d..d24a898612b20 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -154,6 +154,7 @@ class Rope: DIMENSION_COUNT = "{arch}.rope.dimension_count" DIMENSION_SECTIONS = "{arch}.rope.dimension_sections" FREQ_BASE = "{arch}.rope.freq_base" + FREQ_BASE_SWA = "{arch}.rope.freq_base_swa" SCALING_TYPE = 
"{arch}.rope.scaling.type" SCALING_FACTOR = "{arch}.rope.scaling.factor" SCALING_ATTN_FACTOR = "{arch}.rope.scaling.attn_factor" diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py index a6cc8a931eb27..3cc05f3be6494 100644 --- a/gguf-py/gguf/gguf_writer.py +++ b/gguf-py/gguf/gguf_writer.py @@ -814,6 +814,9 @@ def add_iclr_lora_rank(self, length: int) -> None: def add_value_residual_mix_lora_rank(self, length: int) -> None: self.add_uint32(Keys.Attention.VALUE_RESIDUAL_MIX_LORA_RANK.format(arch=self.arch), length) + def add_rope_freq_base_swa(self, value: float) -> None: + self.add_float32(Keys.Rope.FREQ_BASE_SWA.format(arch=self.arch), value) + def add_gate_lora_rank(self, length: int) -> None: self.add_uint32(Keys.Attention.GATE_LORA_RANK.format(arch=self.arch), length) diff --git a/src/llama-vocab.cpp b/src/llama-vocab.cpp index 21420389ec9c1..426f7f7cf6ad7 100644 --- a/src/llama-vocab.cpp +++ b/src/llama-vocab.cpp @@ -1857,7 +1857,8 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { tokenizer_pre == "jina-v2-es" || tokenizer_pre == "jina-v2-de" || tokenizer_pre == "a.x-4.0" || - tokenizer_pre == "mellum") { + tokenizer_pre == "mellum" || + tokenizer_pre == "modern-bert") { pre_type = LLAMA_VOCAB_PRE_TYPE_GPT2; } else if ( tokenizer_pre == "jina-v1-en" || From 6d86944cb494bdc105b86fabfdcc0d4a06c19dc3 Mon Sep 17 00:00:00 2001 From: ryan-mangeno Date: Wed, 3 Sep 2025 14:32:39 -0400 Subject: [PATCH 17/26] working through previous attemp, implimented more accurate conversion per previous attempt, added local sliding window attention that alternates every third layer --- convert_hf_to_gguf.py | 45 ++++----- src/llama-hparams.h | 1 + src/llama-kv-cache-unified.cpp | 12 +++ src/llama-model.cpp | 168 +++++++++++++-------------------- 4 files changed, 101 insertions(+), 125 deletions(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 3c483e10287f8..cff934ba23f5c 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -8308,37 +8308,32 @@ def prepare_tensors(self): raise ValueError(f"Unprocessed experts: {experts}") -@ModelBase.register("ModernBertModel") -class ModernBertModel(TextModel): +@ModelBase.register("ModernBertModel", "ModernBertForMaskedLM", "ModernBertForSequenceClassification") +class ModernBertModel(BertModel): model_arch = gguf.MODEL_ARCH.MODERN_BERT - def set_gguf_parameters(self) -> None: - # Determine block count (number of hidden layers) - block_count = self.hparams.get("num_hidden_layers") or self.hparams.get("num_hidden_layers_alt") - if block_count is None: - raise ValueError("Could not determine number of hidden layers from hparams") + def set_vocab(self): + self._set_vocab_gpt2() + self.gguf_writer.add_add_bos_token(True) + self.gguf_writer.add_add_eos_token(True) - # Attention heads and dimensions - n_head = self.hparams.get("num_attention_heads") - if n_head is None: - raise ValueError("Missing 'num_attention_heads' in hparams") + def set_gguf_parameters(self): + super().set_gguf_parameters() + self.gguf_writer.add_sliding_window(self.hparams["local_attention"]) + self.gguf_writer.add_rope_freq_base(self.hparams["global_rope_theta"]) + self.gguf_writer.add_rope_freq_base_swa(self.hparams["local_rope_theta"]) + self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.NONE) + self.gguf_writer.add_vocab_size(self.hparams["vocab_size"]) - hidden_size = self.hparams["hidden_size"] - head_dim = hidden_size // n_head - ffn_dim = self.hparams.get("intermediate_size", 4 * hidden_size) + def 
modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + # These layers act as MLM head, so we don't need them + if name.startswith("decoder."): + return [] - # GGUF parameter assignment - self.gguf_writer.add_context_length(self.hparams.get("max_position_embeddings", 512)) - self.gguf_writer.add_embedding_length(hidden_size) - self.gguf_writer.add_feed_forward_length(ffn_dim) - self.gguf_writer.add_block_count(block_count) - self.gguf_writer.add_head_count(n_head) - self.gguf_writer.add_layer_norm_eps(self.hparams.get("layer_norm_eps", 1e-12)) - self.gguf_writer.add_file_type(self.ftype) + if name.startswith("model."): + name = name[6:] - def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: - # Directly map tensor names without QKV splitting or reordering - return [(self.map_tensor_name(name), data_torch)] + return super().modify_tensors(data_torch, name, bid) ###### CONVERSION LOGIC ###### diff --git a/src/llama-hparams.h b/src/llama-hparams.h index bd23122443271..2e13a7732c96e 100644 --- a/src/llama-hparams.h +++ b/src/llama-hparams.h @@ -19,6 +19,7 @@ enum llama_swa_type { LLAMA_SWA_TYPE_NONE = 0, LLAMA_SWA_TYPE_STANDARD = 1, LLAMA_SWA_TYPE_CHUNKED = 2, + LLAMA_SWA_TYPE_LOCAL = 3, }; struct llama_hparams_posnet { diff --git a/src/llama-kv-cache-unified.cpp b/src/llama-kv-cache-unified.cpp index e539142e6b8cd..678a7d23ade03 100644 --- a/src/llama-kv-cache-unified.cpp +++ b/src/llama-kv-cache-unified.cpp @@ -1807,6 +1807,18 @@ bool llama_kv_cache_unified::is_masked_swa(llama_pos p0, llama_pos p1) const { return true; } } break; + case LLAMA_SWA_TYPE_LOCAL: + { + const int32_t half_n_swa = (int32_t) n_swa / 2; + const int32_t pos_diff = p1 - p0; + + // mask if outside the window + if (pos_diff < -half_n_swa || pos_diff > half_n_swa) { + return true; + } + } break; + + } return false; diff --git a/src/llama-model.cpp b/src/llama-model.cpp index a159eb347201d..6f70335647678 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -759,11 +759,20 @@ void llama_model::load_hparams(llama_model_loader & ml) { } break; case LLM_ARCH_MODERN_BERT: { - //ml.get_key(LLM_KV_ATTENTION_CAUSAL, hparams.causal_attn); - LLAMA_LOG_INFO("Switching Modern Bert Arch\n"); + + hparams.swa_type = LLAMA_SWA_TYPE_LOCAL; + + hparams.set_swa_pattern(3, 0); + hparams.n_swa = 128; + + ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa); + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); + ml.get_key(LLM_KV_ATTENTION_CAUSAL, hparams.causal_attn); + ml.get_key(LLM_KV_POOLING_TYPE, hparams.pooling_type, false); + switch (hparams.n_layer) { case 12: - type = LLM_TYPE_47M; break; // granite-embeddings-mall + type = LLM_TYPE_47M; break; // granite-embeddings-small default: type = LLM_TYPE_UNKNOWN; } } break; @@ -7544,152 +7553,111 @@ struct llm_build_bert : public llm_graph_context { struct llm_build_modern_bert : public llm_graph_context { llm_build_modern_bert(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { - const int64_t n_embd = hparams.n_embd; - const int64_t n_layer = hparams.n_layer; - const int64_t n_head = hparams.n_head(); - const int64_t n_head_kv = hparams.n_head_kv(); - const int64_t n_embd_head = hparams.n_embd_head_v; - const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); - const int64_t n_tokens = ubatch.n_tokens; + const int64_t n_embd = hparams.n_embd; + const int64_t n_layer = hparams.n_layer; + const int64_t n_head = hparams.n_head(); + 
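+        // ModernBERT is expected to use plain multi-head attention (n_head_kv == n_head),
+        // so n_embd_gqa == n_embd and the fused Wqkv projection maps n_embd -> 3*n_embd.
+        // hparams.n_swa comes from the "local_attention" hparam (128 by default); with the
+        // LLAMA_SWA_TYPE_LOCAL rule added in llama-kv-cache-unified.cpp, local layers mask
+        // any pair of tokens more than n_swa/2 = 64 positions apart.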
const int64_t n_head_kv = hparams.n_head_kv(); + const int64_t n_embd_head = hparams.n_embd_head_v; + const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); + const int64_t n_tokens = ubatch.n_tokens; GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); - // RoPE params - const int32_t rope_type = LLAMA_ROPE_TYPE_NEOX; // uses rotary - const int32_t n_rot = hparams.n_rot; - const int32_t n_ctx_orig = hparams.n_ctx_train; - - ggml_tensor * cur; - ggml_tensor * inpL; - ggml_tensor * inp_pos = nullptr; - - // needs positions for RoPE - inp_pos = build_inp_pos(); + // rope params + const int32_t rope_type = LLAMA_ROPE_TYPE_NEOX; + const int32_t n_rot = hparams.n_rot; + const int32_t n_ctx_orig = hparams.n_ctx_train; + const float freq_base = hparams.rope_freq_base_train; + const float freq_scale = hparams.rope_freq_scale_train; + const float attn_factor = 1.0f; + const float ext_factor = 1.0f; + const float beta_fast = 0.0f; + const float beta_slow = 0.0f; - // embeddings (token + optional type), NO absolute pos embed - inpL = build_inp_embd(model.tok_embd); + ggml_tensor * inp_pos = build_inp_pos(); + ggml_tensor * inpL = build_inp_embd(model.tok_embd); if (model.type_embd) { - ggml_tensor * type_row0 = ggml_view_1d(ctx0, model.type_embd, n_embd, 0); - inpL = ggml_add(ctx0, inpL, type_row0); + inpL = ggml_add(ctx0, inpL, ggml_view_1d(ctx0, model.type_embd, n_embd, 0)); } - cb(inpL, "inp_embd", -1); - - // embeddings LayerNorm (embeddings.norm) inpL = build_norm(inpL, model.tok_norm, model.tok_norm_b, LLM_NORM, -1); - cb(inpL, "inp_norm", -1); - auto * inp_attn = build_attn_inp_no_cache(); + auto * inp_attn = build_attn_inp_no_cache(); ggml_tensor * inp_out_ids = build_inp_out_ids(); for (int il = 0; il < n_layer; ++il) { ggml_tensor * x = inpL; - // pre attention norm (attn_norm). 
Layer 0 may be Identity() -> nullptr + // Pre attention Layer norm ggml_tensor * x_attn_in = x; if (model.layers[il].attn_norm) { - x_attn_in = build_norm(x, - model.layers[il].attn_norm, - model.layers[il].attn_norm_b, - LLM_NORM, il); - cb(x_attn_in, "attn_pre_norm", il); - } else { - cb(x_attn_in, "attn_pre_norm_identity", il); + x_attn_in = build_norm(x, model.layers[il].attn_norm, model.layers[il].attn_norm_b, LLM_NORM, il); } - // Attention: fused Wqkv -> split -> heads -> RoPE(Q,K) -> attn -> Wo - ggml_tensor * qkv = nullptr; - ggml_tensor * Qcur; - ggml_tensor * Kcur; - ggml_tensor * Vcur; - - GGML_ASSERT(model.layers[il].wqkv); // fused QKV - qkv = build_lora_mm(model.layers[il].wqkv, x_attn_in); - cb(qkv, "wqkv", il); - + // fused qkv + GGML_ASSERT(model.layers[il].wqkv); + ggml_tensor * qkv = build_lora_mm(model.layers[il].wqkv, x_attn_in); if (model.layers[il].bqkv) { qkv = ggml_add(ctx0, qkv, model.layers[il].bqkv); - cb(qkv, "bqkv", il); } - // Fused layout: [ (n_embd + 2*n_embd_gqa), n_tokens ] - Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, qkv, n_embd, n_tokens, qkv->nb[1], 0*sizeof(float)*(n_embd))); - Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, qkv, n_embd_gqa, n_tokens, qkv->nb[1], 1*sizeof(float)*(n_embd))); - Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, qkv, n_embd_gqa, n_tokens, qkv->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa))); + ggml_tensor * Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, qkv, n_embd, n_tokens, qkv->nb[1], 0)); + ggml_tensor * Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, qkv, n_embd_gqa, n_tokens, qkv->nb[1], n_embd)); + ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, qkv, n_embd_gqa, n_tokens, qkv->nb[1], n_embd + n_embd_gqa)); - // optional per Q/K - if (model.layers[il].attn_q_norm) { - Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, model.layers[il].attn_q_norm_b, LLM_NORM, il); - } - if (model.layers[il].attn_k_norm) { - Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, model.layers[il].attn_k_norm_b, LLM_NORM, il); - } + // optional q/k LayerNorm + if (model.layers[il].attn_q_norm) Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, model.layers[il].attn_q_norm_b, LLM_NORM, il); + if (model.layers[il].attn_k_norm) Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, model.layers[il].attn_k_norm_b, LLM_NORM, il); - // heads + // reshape for multi head Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); - // RoPE (NEOX ... maybe?) 
on Q and K + // rope embedding Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, nullptr, - n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, - ext_factor, attn_factor, beta_fast, beta_slow); + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow); Kcur = ggml_rope_ext(ctx0, Kcur, inp_pos, nullptr, - n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, - ext_factor, attn_factor, beta_fast, beta_slow); - - cb(Qcur, "Qcur_rope", il); - cb(Kcur, "Kcur_rope", il); - cb(Vcur, "Vcur", il); + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow); ggml_tensor * attn_out = build_attn( inp_attn, - model.layers[il].wo, model.layers[il].bo, // Wo, optional bias + model.layers[il].wo, model.layers[il].bo, Qcur, Kcur, Vcur, - /*K_cache*/ nullptr, - /*V_cache*/ nullptr, + /*k cache*/ nullptr, + /*v cache*/ nullptr, 1.0f / sqrtf(float(n_embd_head)), - il); - cb(attn_out, "attn_out", il); + il + ); - // residual after attention ggml_tensor * cur_attn = ggml_add(ctx0, attn_out, x); - // ifwe subselect outputs, do it at the last layer after attn resid + // optional subselect output tokens (inp_out_ids) if (il == n_layer - 1 && inp_out_ids) { - cur_attn = ggml_get_rows(ctx0, cur_attn, inp_out_ids); - x = ggml_get_rows(ctx0, x, inp_out_ids); + cur_attn = ggml_get_rows(ctx0, cur_attn, inp_out_ids); + inpL = ggml_get_rows(ctx0, inpL, inp_out_ids); } - // pre mlp norm - ggml_tensor * h = build_norm(cur_attn, - model.layers[il].ffn_norm, - model.layers[il].ffn_norm_b, - LLM_NORM, il); - cb(h, "mlp_pre_norm", il); + // pre mlp LayerNorm + ggml_tensor * h = build_norm(cur_attn, model.layers[il].ffn_norm, model.layers[il].ffn_norm_b, LLM_NORM, il); - // GEGLU because we will split ffn_up which has shape [n_embd, n_ff * 2] and ffn_down has shape [n_ff, n_embd] + // geglu FFN ggml_tensor * mlp_out = build_ffn( h, - model.layers[il].ffn_up, /*up_b*/ NULL, /*up_shexp*/ NULL, - /*gate*/ NULL , /*gate_b*/ NULL, /*gate_shexp*/ NULL, - model.layers[il].ffn_down, /*down_b*/ NULL, /*down_shexp*/ NULL, - /*act_scales*/ NULL, + model.layers[il].ffn_up, NULL, NULL, + NULL, NULL, NULL, + model.layers[il].ffn_down, NULL, NULL, + NULL, LLM_FFN_GEGLU, LLM_FFN_PAR, il ); - cb(mlp_out, "ffn_out_geglu", il); - // Residual after MLP - ggml_tensor * cur_layer = ggml_add(ctx0, mlp_out, cur_attn); - - // feed into next layer - inpL = cur_layer; + // resid addition + inpL = ggml_add(ctx0, mlp_out, cur_attn); } - // final model norm (final_norm) - cur = build_norm(inpL, model.output_norm, model.output_norm_b, LLM_NORM, -1); - cb(cur, "final_norm", -1); - + ggml_tensor * cur = build_norm(inpL, model.output_norm, model.output_norm_b, LLM_NORM, -1); res->t_embd = cur; ggml_build_forward_expand(gf, cur); } From 39c029144b20973056f89e743b8cf20c8fe5382d Mon Sep 17 00:00:00 2001 From: ryan-mangeno Date: Wed, 3 Sep 2025 14:34:51 -0400 Subject: [PATCH 18/26] fixed pre tokenizer --- src/llama-vocab.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/llama-vocab.cpp b/src/llama-vocab.cpp index 426f7f7cf6ad7..32683879083e0 100644 --- a/src/llama-vocab.cpp +++ b/src/llama-vocab.cpp @@ -1813,7 +1813,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { LLAMA_LOG_WARN("%s: ************************************ \n", __func__); LLAMA_LOG_WARN("%s: \n", __func__); pre_type = LLAMA_VOCAB_PRE_TYPE_DEFAULT; - } else if (tokenizer_pre == "default" || tokenizer_pre == "modern-bert") /* need to fix modern-bert pre tokenizer */ { + 
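+            // "modern-bert" is now matched by the GPT-2 style pre-tokenizer group earlier in
+            // this chain (next to "mellum"), so the default branch no longer special-cases it.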
} else if (tokenizer_pre == "default") { pre_type = LLAMA_VOCAB_PRE_TYPE_DEFAULT; } else if ( tokenizer_pre == "llama3" || From e101005d1af7ff184234a6f7d0ca054ef2dfb2dd Mon Sep 17 00:00:00 2001 From: ryan-mangeno Date: Sun, 7 Sep 2025 21:00:38 -0400 Subject: [PATCH 19/26] working on swa with local and global alternating attention --- src/llama-arch.cpp | 1 + src/llama-arch.h | 1 + src/llama-model.cpp | 130 +++++++++++++++++++++++++++++++++++++------- 3 files changed, 113 insertions(+), 19 deletions(-) diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp index 9a009ac902ccc..cbb1f3d8f64ff 100644 --- a/src/llama-arch.cpp +++ b/src/llama-arch.cpp @@ -171,6 +171,7 @@ static const std::map LLM_KV_NAMES = { { LLM_KV_ROPE_DIMENSION_COUNT, "%s.rope.dimension_count" }, { LLM_KV_ROPE_DIMENSION_SECTIONS, "%s.rope.dimension_sections" }, { LLM_KV_ROPE_FREQ_BASE, "%s.rope.freq_base" }, + { LLM_KV_ROPE_FREQ_BASE_SWA, "%s.rope.freq_base_swa" }, { LLM_KV_ROPE_SCALE_LINEAR, "%s.rope.scale_linear" }, { LLM_KV_ROPE_SCALING_TYPE, "%s.rope.scaling.type" }, { LLM_KV_ROPE_SCALING_FACTOR, "%s.rope.scaling.factor" }, diff --git a/src/llama-arch.h b/src/llama-arch.h index c99448e78f481..8422cbe2a1726 100644 --- a/src/llama-arch.h +++ b/src/llama-arch.h @@ -176,6 +176,7 @@ enum llm_kv { LLM_KV_ROPE_DIMENSION_SECTIONS, LLM_KV_ROPE_FREQ_BASE, LLM_KV_ROPE_SCALE_LINEAR, + LLM_KV_ROPE_FREQ_BASE_SWA, LLM_KV_ROPE_SCALING_TYPE, LLM_KV_ROPE_SCALING_FACTOR, LLM_KV_ROPE_SCALING_ATTN_FACTOR, diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 6f70335647678..a3b4646f0b5ef 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -7559,6 +7559,7 @@ struct llm_build_modern_bert : public llm_graph_context { const int64_t n_head_kv = hparams.n_head_kv(); const int64_t n_embd_head = hparams.n_embd_head_v; const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); + const int64_t n_local_swa = hparams.n_swa; const int64_t n_tokens = ubatch.n_tokens; GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); @@ -7574,7 +7575,17 @@ struct llm_build_modern_bert : public llm_graph_context { const float beta_fast = 0.0f; const float beta_slow = 0.0f; - ggml_tensor * inp_pos = build_inp_pos(); + + ggml_tensor *inp_pos_global = ggml_new_tensor_2d(ctx0, GGML_TYPE_I32, 4096, 1); + ggml_set_input(inp_pos_global); + size_t element_size = ggml_type_size(inp_pos_global->type); + + size_t nb1 = element_size; + size_t nb2 = nb1; + + inp_pos_global = ggml_view_3d(ctx0, inp_pos_global, 1, 1, 4096, nb1, nb2, 0); + inp_pos_global = ggml_cont(ctx0, inp_pos_global); + ggml_tensor * inpL = build_inp_embd(model.tok_embd); if (model.type_embd) { @@ -7582,19 +7593,20 @@ struct llm_build_modern_bert : public llm_graph_context { } inpL = build_norm(inpL, model.tok_norm, model.tok_norm_b, LLM_NORM, -1); - auto * inp_attn = build_attn_inp_no_cache(); + auto * inp_attn = build_attn_inp_kv_unified_iswa(); ggml_tensor * inp_out_ids = build_inp_out_ids(); + for (int il = 0; il < n_layer; ++il) { ggml_tensor * x = inpL; - // Pre attention Layer norm + // pre attn LayerNorm ggml_tensor * x_attn_in = x; if (model.layers[il].attn_norm) { x_attn_in = build_norm(x, model.layers[il].attn_norm, model.layers[il].attn_norm_b, LLM_NORM, il); } - // fused qkv + // fused QKV GGML_ASSERT(model.layers[il].wqkv); ggml_tensor * qkv = build_lora_mm(model.layers[il].wqkv, x_attn_in); if (model.layers[il].bqkv) { @@ -7609,41 +7621,120 @@ struct llm_build_modern_bert : public llm_graph_context { if (model.layers[il].attn_q_norm) Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, 
model.layers[il].attn_q_norm_b, LLM_NORM, il); if (model.layers[il].attn_k_norm) Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, model.layers[il].attn_k_norm_b, LLM_NORM, il); - // reshape for multi head + // reshape for multi-head attention Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); - // rope embedding - Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, nullptr, - n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, - ext_factor, attn_factor, beta_fast, beta_slow); - Kcur = ggml_rope_ext(ctx0, Kcur, inp_pos, nullptr, - n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + // global or local layer + bool is_global = ((il + 1) % 3 == 0); + float freq_base_l = is_global ? 160000.0f : 10000.0f; // rope theta + float freq_scale_l = 1.0f; + + ggml_tensor * pos_q = inp_pos_global; + + ggml_tensor * K_work = Kcur; + ggml_tensor * V_work = Vcur; + ggml_tensor * pos_k = inp_pos_global; + + if (!is_global) { + ggml_tensor * idx_src = inp_attn->self_k_idxs_swa; + + ggml_tensor * idx_view1d = ggml_view_1d(ctx0, idx_src, idx_src->ne[0], 0); + ggml_tensor * idx_cont = ggml_cont(ctx0, idx_view1d); + + ggml_tensor * idx_i32 = idx_cont; + if (idx_i32->type != GGML_TYPE_I32) { + idx_i32 = ggml_cast(ctx0, idx_cont, GGML_TYPE_I32); + } + + const int64_t n_indices = idx_i32->ne[0]; + ggml_tensor * idx_2d = ggml_view_2d(ctx0, idx_i32, 1, n_indices, sizeof(int32_t), 0); + + idx_2d = ggml_cont(ctx0, idx_2d); + if (idx_2d->type != GGML_TYPE_I32) idx_2d = ggml_cast(ctx0, idx_2d, GGML_TYPE_I32); + + Kcur->ne[0], Kcur->ne[1], Kcur->ne[2], + idx_2d->ne[0], idx_2d->ne[1], idx_2d->ne[2], idx_2d->ne[3], + idx_2d->type); + + K_work = ggml_get_rows(ctx0, Kcur, idx_2d); + V_work = ggml_get_rows(ctx0, Vcur, idx_2d); + + + + ggml_tensor * pos_rows = ggml_get_rows(ctx0, inp_pos_global, idx_2d); + + if (!ggml_is_vector(pos_rows)) { + const int64_t n_el = ggml_nelements(pos_rows); + pos_rows = ggml_view_1d(ctx0, pos_rows, n_el, 0); + pos_rows = ggml_cont(ctx0, pos_rows); + } else { + pos_rows = ggml_cont(ctx0, pos_rows); + } + // ensure I32 + if (pos_rows->type != GGML_TYPE_I32) { + pos_rows = ggml_cast(ctx0, pos_rows, GGML_TYPE_I32); + } + + // final pos_k to pass to rope + pos_k = pos_rows; + LLAMA_LOG_INFO("pos_k final: ne[0]=%lld, type=%d\n", pos_k->ne[0], pos_k->type); + } + + if( !ggml_is_vector(pos_q) ) { + const int64_t n_el = ggml_nelements(pos_q); + pos_q = ggml_view_1d(ctx0, pos_q, n_el, 0); + pos_q = ggml_cont(ctx0, pos_q); + } + if( !ggml_is_vector(pos_q) ) { + } + + + // apply rope + Qcur = ggml_rope_ext(ctx0, Qcur, pos_q, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l, + ext_factor, attn_factor, beta_fast, beta_slow); + + if( !ggml_is_vector(pos_k) ) { + const int64_t n_el = ggml_nelements(pos_k); + pos_k = ggml_view_1d(ctx0, pos_k, n_el, 0); + pos_k = ggml_cont(ctx0, pos_k); + } + + K_work = ggml_rope_ext(ctx0, K_work, pos_k, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l, ext_factor, attn_factor, beta_fast, beta_slow); + // choseing mask, global vs swa + ggml_tensor * kq_b_layer = is_global ? 
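+                // both masks come from the build_attn_inp_kv_unified_iswa() input created before
+                // the layer loop: the full-context mask for global layers, the windowed mask otherwise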
inp_attn->self_kq_mask : inp_attn->self_kq_mask_swa; + ggml_tensor * attn_out = build_attn( inp_attn, - model.layers[il].wo, model.layers[il].bo, - Qcur, Kcur, Vcur, - /*k cache*/ nullptr, - /*v cache*/ nullptr, + model.layers[il].wo, + model.layers[il].bo, + Qcur, + K_work, + V_work, + kq_b_layer, + nullptr, 1.0f / sqrtf(float(n_embd_head)), il ); + // residual addition ggml_tensor * cur_attn = ggml_add(ctx0, attn_out, x); - // optional subselect output tokens (inp_out_ids) + // optional output select if (il == n_layer - 1 && inp_out_ids) { cur_attn = ggml_get_rows(ctx0, cur_attn, inp_out_ids); inpL = ggml_get_rows(ctx0, inpL, inp_out_ids); } - // pre mlp LayerNorm + // pre mlp layer norm ggml_tensor * h = build_norm(cur_attn, model.layers[il].ffn_norm, model.layers[il].ffn_norm_b, LLM_NORM, il); - // geglu FFN + // geglu ffn ggml_tensor * mlp_out = build_ffn( h, model.layers[il].ffn_up, NULL, NULL, @@ -7653,10 +7744,11 @@ struct llm_build_modern_bert : public llm_graph_context { LLM_FFN_GEGLU, LLM_FFN_PAR, il ); - // resid addition + // resudi addition after FFN inpL = ggml_add(ctx0, mlp_out, cur_attn); } + ggml_tensor * cur = build_norm(inpL, model.output_norm, model.output_norm_b, LLM_NORM, -1); res->t_embd = cur; ggml_build_forward_expand(gf, cur); From 044bc7d5cddaece3a1c14eadfb59016efe4ec47e Mon Sep 17 00:00:00 2001 From: ryan-mangeno Date: Mon, 8 Sep 2025 12:21:18 -0400 Subject: [PATCH 20/26] some cleanup and now fails on build attn --- src/llama-graph.cpp | 2 +- src/llama-model.cpp | 16 +++++++--------- 2 files changed, 8 insertions(+), 10 deletions(-) diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index 9ca2e579d7299..8760046c843b7 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -1547,7 +1547,7 @@ ggml_tensor * llm_graph_context::build_attn_with_sinks( // optionally store to KV cache if (k_cur) { const auto & k_idxs = is_swa ? inp->get_k_idxs_swa() : inp->get_k_idxs(); - + LLAMA_LOG_INFO("k_cur.shape = {%lld, %lld, %lld, %lld}\n", k_cur->ne[0], k_cur->ne[1], k_cur->ne[2], k_cur->ne[3]); ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, k_idxs, il)); } diff --git a/src/llama-model.cpp b/src/llama-model.cpp index a3b4646f0b5ef..8966cdcf1216f 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -7654,10 +7654,6 @@ struct llm_build_modern_bert : public llm_graph_context { idx_2d = ggml_cont(ctx0, idx_2d); if (idx_2d->type != GGML_TYPE_I32) idx_2d = ggml_cast(ctx0, idx_2d, GGML_TYPE_I32); - Kcur->ne[0], Kcur->ne[1], Kcur->ne[2], - idx_2d->ne[0], idx_2d->ne[1], idx_2d->ne[2], idx_2d->ne[3], - idx_2d->type); - K_work = ggml_get_rows(ctx0, Kcur, idx_2d); V_work = ggml_get_rows(ctx0, Vcur, idx_2d); @@ -7679,7 +7675,7 @@ struct llm_build_modern_bert : public llm_graph_context { // final pos_k to pass to rope pos_k = pos_rows; - LLAMA_LOG_INFO("pos_k final: ne[0]=%lld, type=%d\n", pos_k->ne[0], pos_k->type); + LLAMA_LOG_INFO("pos_k final: ne[0]=%lld, ne[1]=%lld type=%d\n", pos_k->ne[0], pos_k->ne[1], pos_k->type); } if( !ggml_is_vector(pos_q) ) { @@ -7707,7 +7703,9 @@ struct llm_build_modern_bert : public llm_graph_context { ext_factor, attn_factor, beta_fast, beta_slow); // choseing mask, global vs swa - ggml_tensor * kq_b_layer = is_global ? inp_attn->self_kq_mask : inp_attn->self_kq_mask_swa; + ggml_tensor * kq_mask = is_global ? 
inp_attn->self_kq_mask : inp_attn->self_kq_mask_swa; + + ggml_tensor * attn_out = build_attn( inp_attn, @@ -7716,14 +7714,14 @@ struct llm_build_modern_bert : public llm_graph_context { Qcur, K_work, V_work, - kq_b_layer, + kq_mask, nullptr, 1.0f / sqrtf(float(n_embd_head)), il ); - // residual addition - ggml_tensor * cur_attn = ggml_add(ctx0, attn_out, x); + + ggml_tensor * cur_attn = ggml_add(ctx0, x, attn_out); // optional output select if (il == n_layer - 1 && inp_out_ids) { From e296a0b6e694283ef0fc52c1bf9a780cff930c77 Mon Sep 17 00:00:00 2001 From: ryan-mangeno Date: Mon, 8 Sep 2025 15:38:13 -0400 Subject: [PATCH 21/26] starting to work, and some cleanup, currently failing on last layer construction in graph build --- src/llama-graph.cpp | 2 +- src/llama-model.cpp | 21 +++++++++++++-------- 2 files changed, 14 insertions(+), 9 deletions(-) diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index 8760046c843b7..9ca2e579d7299 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -1547,7 +1547,7 @@ ggml_tensor * llm_graph_context::build_attn_with_sinks( // optionally store to KV cache if (k_cur) { const auto & k_idxs = is_swa ? inp->get_k_idxs_swa() : inp->get_k_idxs(); - LLAMA_LOG_INFO("k_cur.shape = {%lld, %lld, %lld, %lld}\n", k_cur->ne[0], k_cur->ne[1], k_cur->ne[2], k_cur->ne[3]); + ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, k_idxs, il)); } diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 8966cdcf1216f..34cd49083be51 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -7598,6 +7598,7 @@ struct llm_build_modern_bert : public llm_graph_context { for (int il = 0; il < n_layer; ++il) { + LLAMA_LOG_INFO("Setting layer %d\n", il); ggml_tensor * x = inpL; // pre attn LayerNorm @@ -7656,8 +7657,6 @@ struct llm_build_modern_bert : public llm_graph_context { K_work = ggml_get_rows(ctx0, Kcur, idx_2d); V_work = ggml_get_rows(ctx0, Vcur, idx_2d); - - ggml_tensor * pos_rows = ggml_get_rows(ctx0, inp_pos_global, idx_2d); @@ -7675,7 +7674,6 @@ struct llm_build_modern_bert : public llm_graph_context { // final pos_k to pass to rope pos_k = pos_rows; - LLAMA_LOG_INFO("pos_k final: ne[0]=%lld, ne[1]=%lld type=%d\n", pos_k->ne[0], pos_k->ne[1], pos_k->type); } if( !ggml_is_vector(pos_q) ) { @@ -7683,9 +7681,6 @@ struct llm_build_modern_bert : public llm_graph_context { pos_q = ggml_view_1d(ctx0, pos_q, n_el, 0); pos_q = ggml_cont(ctx0, pos_q); } - if( !ggml_is_vector(pos_q) ) { - } - // apply rope Qcur = ggml_rope_ext(ctx0, Qcur, pos_q, nullptr, @@ -7705,6 +7700,16 @@ struct llm_build_modern_bert : public llm_graph_context { // choseing mask, global vs swa ggml_tensor * kq_mask = is_global ? 
inp_attn->self_kq_mask : inp_attn->self_kq_mask_swa; + // flatten K/V back to full embedding dim + int64_t n_embd = n_embd_head * n_head_kv; + int64_t n_tokens = Kcur->ne[2]; + + ggml_tensor *K_2d = ggml_reshape_2d(ctx0, Kcur, n_embd, n_tokens); + + ggml_tensor *K_flat = ggml_view_3d(ctx0, K_2d, n_embd, 1, n_tokens, + K_2d->nb[0], K_2d->nb[1], 0); + K_flat = ggml_cont(ctx0, K_flat); + ggml_tensor * V_flat = ggml_reshape_2d(ctx0, Vcur, n_embd, n_tokens); ggml_tensor * attn_out = build_attn( @@ -7712,8 +7717,8 @@ struct llm_build_modern_bert : public llm_graph_context { model.layers[il].wo, model.layers[il].bo, Qcur, - K_work, - V_work, + K_flat, + V_flat, kq_mask, nullptr, 1.0f / sqrtf(float(n_embd_head)), From 2bacfb0bc2b24215937ab32c859db2ccb3b446fb Mon Sep 17 00:00:00 2001 From: ryan-mangeno Date: Thu, 11 Sep 2025 16:37:18 -0400 Subject: [PATCH 22/26] alternating rope implemented and modern bert graph build succeeds --- src/llama-model.cpp | 245 ++++++++++++++------------------------------ 1 file changed, 78 insertions(+), 167 deletions(-) diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 34cd49083be51..6aa1426a2885f 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -7551,209 +7551,117 @@ struct llm_build_bert : public llm_graph_context { }; struct llm_build_modern_bert : public llm_graph_context { - llm_build_modern_bert(const llama_model & model, const llm_graph_params & params) - : llm_graph_context(params) { - const int64_t n_embd = hparams.n_embd; - const int64_t n_layer = hparams.n_layer; - const int64_t n_head = hparams.n_head(); - const int64_t n_head_kv = hparams.n_head_kv(); + llm_build_modern_bert(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v; const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); - const int64_t n_local_swa = hparams.n_swa; - const int64_t n_tokens = ubatch.n_tokens; GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); - // rope params - const int32_t rope_type = LLAMA_ROPE_TYPE_NEOX; - const int32_t n_rot = hparams.n_rot; - const int32_t n_ctx_orig = hparams.n_ctx_train; - const float freq_base = hparams.rope_freq_base_train; - const float freq_scale = hparams.rope_freq_scale_train; - const float attn_factor = 1.0f; - const float ext_factor = 1.0f; - const float beta_fast = 0.0f; - const float beta_slow = 0.0f; - - - ggml_tensor *inp_pos_global = ggml_new_tensor_2d(ctx0, GGML_TYPE_I32, 4096, 1); - ggml_set_input(inp_pos_global); - size_t element_size = ggml_type_size(inp_pos_global->type); - - size_t nb1 = element_size; - size_t nb2 = nb1; + ggml_tensor * cur; + ggml_tensor * inpL; + ggml_tensor * inp_pos = build_inp_pos(); // Initialize inp_pos with build_inp_pos() - inp_pos_global = ggml_view_3d(ctx0, inp_pos_global, 1, 1, 4096, nb1, nb2, 0); - inp_pos_global = ggml_cont(ctx0, inp_pos_global); - - ggml_tensor * inpL = build_inp_embd(model.tok_embd); + // construct input embeddings (token, type, position) + inpL = build_inp_embd(model.tok_embd); + cb(inpL, "inp_embd", -1); - if (model.type_embd) { - inpL = ggml_add(ctx0, inpL, ggml_view_1d(ctx0, model.type_embd, n_embd, 0)); - } - inpL = build_norm(inpL, model.tok_norm, model.tok_norm_b, LLM_NORM, -1); + // embed layer norm + inpL = build_norm(inpL, model.tok_norm, nullptr, LLM_NORM, -1); + cb(inpL, "inp_norm", -1); auto * inp_attn = build_attn_inp_kv_unified_iswa(); - ggml_tensor * inp_out_ids = build_inp_out_ids(); - + // iterate layers for (int il = 0; il < n_layer; ++il) { - 
LLAMA_LOG_INFO("Setting layer %d\n", il); - ggml_tensor * x = inpL; + ggml_tensor * cur = inpL; - // pre attn LayerNorm - ggml_tensor * x_attn_in = x; - if (model.layers[il].attn_norm) { - x_attn_in = build_norm(x, model.layers[il].attn_norm, model.layers[il].attn_norm_b, LLM_NORM, il); - } + ggml_tensor * Qcur; + ggml_tensor * Kcur; + ggml_tensor * Vcur; - // fused QKV - GGML_ASSERT(model.layers[il].wqkv); - ggml_tensor * qkv = build_lora_mm(model.layers[il].wqkv, x_attn_in); - if (model.layers[il].bqkv) { - qkv = ggml_add(ctx0, qkv, model.layers[il].bqkv); + float rope_theta = il % 3 == 0 ? hparams.rope_freq_base_train : hparams.rope_freq_base_train_swa; + + // attention layer norm + if (model.layers[il].attn_norm) { + cur = build_norm(inpL, + model.layers[il].attn_norm, NULL, + LLM_NORM, il); + cb(cur, "attn_norm", il); } - ggml_tensor * Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, qkv, n_embd, n_tokens, qkv->nb[1], 0)); - ggml_tensor * Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, qkv, n_embd_gqa, n_tokens, qkv->nb[1], n_embd)); - ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, qkv, n_embd_gqa, n_tokens, qkv->nb[1], n_embd + n_embd_gqa)); + // self attention + cur = build_lora_mm(model.layers[il].wqkv, cur); + cb(cur, "wqkv", il); - // optional q/k LayerNorm - if (model.layers[il].attn_q_norm) Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, model.layers[il].attn_q_norm_b, LLM_NORM, il); - if (model.layers[il].attn_k_norm) Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, model.layers[il].attn_k_norm_b, LLM_NORM, il); + Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd))); + Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd))); + Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa))); - // reshape for multi-head attention Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); - // global or local layer - bool is_global = ((il + 1) % 3 == 0); - float freq_base_l = is_global ? 
160000.0f : 10000.0f; // rope theta - float freq_scale_l = 1.0f; - - ggml_tensor * pos_q = inp_pos_global; - - ggml_tensor * K_work = Kcur; - ggml_tensor * V_work = Vcur; - ggml_tensor * pos_k = inp_pos_global; - - if (!is_global) { - ggml_tensor * idx_src = inp_attn->self_k_idxs_swa; - - ggml_tensor * idx_view1d = ggml_view_1d(ctx0, idx_src, idx_src->ne[0], 0); - ggml_tensor * idx_cont = ggml_cont(ctx0, idx_view1d); - - ggml_tensor * idx_i32 = idx_cont; - if (idx_i32->type != GGML_TYPE_I32) { - idx_i32 = ggml_cast(ctx0, idx_cont, GGML_TYPE_I32); - } - - const int64_t n_indices = idx_i32->ne[0]; - ggml_tensor * idx_2d = ggml_view_2d(ctx0, idx_i32, 1, n_indices, sizeof(int32_t), 0); - - idx_2d = ggml_cont(ctx0, idx_2d); - if (idx_2d->type != GGML_TYPE_I32) idx_2d = ggml_cast(ctx0, idx_2d, GGML_TYPE_I32); - - K_work = ggml_get_rows(ctx0, Kcur, idx_2d); - V_work = ggml_get_rows(ctx0, Vcur, idx_2d); - - ggml_tensor * pos_rows = ggml_get_rows(ctx0, inp_pos_global, idx_2d); - - if (!ggml_is_vector(pos_rows)) { - const int64_t n_el = ggml_nelements(pos_rows); - pos_rows = ggml_view_1d(ctx0, pos_rows, n_el, 0); - pos_rows = ggml_cont(ctx0, pos_rows); - } else { - pos_rows = ggml_cont(ctx0, pos_rows); - } - // ensure I32 - if (pos_rows->type != GGML_TYPE_I32) { - pos_rows = ggml_cast(ctx0, pos_rows, GGML_TYPE_I32); - } + // RoPE + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, rope_theta, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); - // final pos_k to pass to rope - pos_k = pos_rows; - } + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, rope_theta, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); - if( !ggml_is_vector(pos_q) ) { - const int64_t n_el = ggml_nelements(pos_q); - pos_q = ggml_view_1d(ctx0, pos_q, n_el, 0); - pos_q = ggml_cont(ctx0, pos_q); - } + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); - // apply rope - Qcur = ggml_rope_ext(ctx0, Qcur, pos_q, nullptr, - n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l, - ext_factor, attn_factor, beta_fast, beta_slow); + cur = build_attn(inp_attn, + model.layers[il].wo, nullptr, + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + cb(cur, "kqv_out", il); - if( !ggml_is_vector(pos_k) ) { - const int64_t n_el = ggml_nelements(pos_k); - pos_k = ggml_view_1d(ctx0, pos_k, n_el, 0); - pos_k = ggml_cont(ctx0, pos_k); + if (il == n_layer - 1 && pooling_type == LLAMA_POOLING_TYPE_NONE) { + // skip computing output for unused tokens + ggml_tensor * inp_out_ids = build_inp_out_ids(); + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpL = ggml_get_rows(ctx0, inpL, inp_out_ids); } - K_work = ggml_rope_ext(ctx0, K_work, pos_k, nullptr, - n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l, - ext_factor, attn_factor, beta_fast, beta_slow); - - // choseing mask, global vs swa - ggml_tensor * kq_mask = is_global ? 
inp_attn->self_kq_mask : inp_attn->self_kq_mask_swa; - - // flatten K/V back to full embedding dim - int64_t n_embd = n_embd_head * n_head_kv; - int64_t n_tokens = Kcur->ne[2]; - - ggml_tensor *K_2d = ggml_reshape_2d(ctx0, Kcur, n_embd, n_tokens); - - ggml_tensor *K_flat = ggml_view_3d(ctx0, K_2d, n_embd, 1, n_tokens, - K_2d->nb[0], K_2d->nb[1], 0); - K_flat = ggml_cont(ctx0, K_flat); - ggml_tensor * V_flat = ggml_reshape_2d(ctx0, Vcur, n_embd, n_tokens); - - - ggml_tensor * attn_out = build_attn( - inp_attn, - model.layers[il].wo, - model.layers[il].bo, - Qcur, - K_flat, - V_flat, - kq_mask, - nullptr, - 1.0f / sqrtf(float(n_embd_head)), - il - ); + // re-add the layer input + cur = ggml_add(ctx0, cur, inpL); - - ggml_tensor * cur_attn = ggml_add(ctx0, x, attn_out); + // attention layer norm + cur = build_norm(cur, model.layers[il].attn_out_norm, nullptr, LLM_NORM, il); - // optional output select - if (il == n_layer - 1 && inp_out_ids) { - cur_attn = ggml_get_rows(ctx0, cur_attn, inp_out_ids); - inpL = ggml_get_rows(ctx0, inpL, inp_out_ids); - } + ggml_tensor * ffn_inp = cur; + cb(ffn_inp, "ffn_inp", il); - // pre mlp layer norm - ggml_tensor * h = build_norm(cur_attn, model.layers[il].ffn_norm, model.layers[il].ffn_norm_b, LLM_NORM, il); + cur = build_ffn(cur, + model.layers[il].ffn_up, + NULL, NULL, NULL, NULL, NULL, + model.layers[il].ffn_down, + NULL, NULL, NULL, + LLM_FFN_GEGLU, LLM_FFN_SEQ, il); - // geglu ffn - ggml_tensor * mlp_out = build_ffn( - h, - model.layers[il].ffn_up, NULL, NULL, - NULL, NULL, NULL, - model.layers[il].ffn_down, NULL, NULL, - NULL, - LLM_FFN_GEGLU, LLM_FFN_PAR, il - ); + // attentions bypass the intermediate layer + cur = ggml_add(ctx0, cur, ffn_inp); - // resudi addition after FFN - inpL = ggml_add(ctx0, mlp_out, cur_attn); + // input for next layer + inpL = cur; } + cur = inpL; - ggml_tensor * cur = build_norm(inpL, model.output_norm, model.output_norm_b, LLM_NORM, -1); + cur = build_norm(cur, + model.output_norm_enc, NULL, + LLM_NORM, -1); + + cb(cur, "result_embd", -1); res->t_embd = cur; + ggml_build_forward_expand(gf, cur); } }; @@ -18450,6 +18358,7 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const { case LLM_ARCH_MODERN_BERT: { llm = std::make_unique(*this, params); + LLAMA_LOG_INFO("Built llm\n"); } break; case LLM_ARCH_NEO_BERT: { @@ -18768,6 +18677,8 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const { GGML_ABORT("fatal error"); } + LLAMA_LOG_INFO("Building pooling\n"); + // add on pooling layer llm->build_pooling(cls, cls_b, cls_out, cls_out_b); From 4e7c8793ae118783cbd390f9dceea8963e019c39 Mon Sep 17 00:00:00 2001 From: ryan-mangeno Date: Thu, 11 Sep 2025 16:41:04 -0400 Subject: [PATCH 23/26] fixed asser for equal ubatch seq --- src/llama-graph.cpp | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index 9ca2e579d7299..4bd41d0ed623d 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -1375,9 +1375,7 @@ ggml_tensor * llm_graph_context::build_attn( // [TAG_NO_CACHE_PAD] // TODO: if ubatch.equal_seqs() == true, we can split the three tensors below into ubatch.n_seqs_unq streams - if (ubatch.n_seqs > 1) { - assert(!ubatch.equal_seqs()); - } + assert(!ubatch.equal_seqs()); ggml_tensor * q = q_cur; ggml_tensor * k = k_cur; @@ -1547,7 +1545,6 @@ ggml_tensor * llm_graph_context::build_attn_with_sinks( // optionally store to KV cache if (k_cur) { const auto & k_idxs = is_swa ? 
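            // is_swa selects the destination indices of the sliding-window cache instead of the full cache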
inp->get_k_idxs_swa() : inp->get_k_idxs(); - ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, k_idxs, il)); } From 20d448a8d72145d46c9af4b745790f896737f909 Mon Sep 17 00:00:00 2001 From: ryan-mangeno Date: Thu, 11 Sep 2025 16:42:41 -0400 Subject: [PATCH 24/26] cleanup --- src/llama-model.cpp | 3 --- 1 file changed, 3 deletions(-) diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 6aa1426a2885f..58b885c9cf548 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -18358,7 +18358,6 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const { case LLM_ARCH_MODERN_BERT: { llm = std::make_unique(*this, params); - LLAMA_LOG_INFO("Built llm\n"); } break; case LLM_ARCH_NEO_BERT: { @@ -18677,8 +18676,6 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const { GGML_ABORT("fatal error"); } - LLAMA_LOG_INFO("Building pooling\n"); - // add on pooling layer llm->build_pooling(cls, cls_b, cls_out, cls_out_b); From db4f5656e44468a6d8417e90b363f8b8e20583d1 Mon Sep 17 00:00:00 2001 From: ryan-mangeno Date: Fri, 12 Sep 2025 11:45:02 -0400 Subject: [PATCH 25/26] added mask check in vocab --- src/llama-model.cpp | 3 ++- src/llama-vocab.cpp | 7 +++++++ 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 58b885c9cf548..7a05491868cb6 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -7550,6 +7550,7 @@ struct llm_build_bert : public llm_graph_context { } }; +template struct llm_build_modern_bert : public llm_graph_context { llm_build_modern_bert(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v; @@ -18357,7 +18358,7 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const { } break; case LLM_ARCH_MODERN_BERT: { - llm = std::make_unique(*this, params); + llm = std::make_unique>(*this, params); } break; case LLM_ARCH_NEO_BERT: { diff --git a/src/llama-vocab.cpp b/src/llama-vocab.cpp index 32683879083e0..00fbe4db1de2d 100644 --- a/src/llama-vocab.cpp +++ b/src/llama-vocab.cpp @@ -2487,6 +2487,13 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { for (const auto * token : {"", "", "<|endoftext|>"}) { _set_token_attr(token, LLAMA_TOKEN_ATTR_RSTRIP, false); } + } else if ( _contains_any(model_name, {"modern-bert"})) { + if ( token_to_id.count("MASK") == 0 ) { + LLAMA_LOG_WARN("%s: Mask token missing in vocab!\n", __func__); + } + else { + _set_token_attr("[MASK]", LLAMA_TOKEN_ATTR_LSTRIP, true); + } } } } From da0604a5487ffd830ec85616f0c6005462d8a913 Mon Sep 17 00:00:00 2001 From: ryan-mangeno Date: Fri, 12 Sep 2025 16:50:15 -0400 Subject: [PATCH 26/26] fixed alternating rope, the hparams.rope_freq_base_train and hparams.rope_freq_base_train_swa were the same and i set them to correct values --- src/llama-model.cpp | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 7a05491868cb6..448c7320ac77e 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -763,7 +763,9 @@ void llama_model::load_hparams(llama_model_loader & ml) { hparams.swa_type = LLAMA_SWA_TYPE_LOCAL; hparams.set_swa_pattern(3, 0); - hparams.n_swa = 128; + hparams.rope_freq_base_train_swa = 10000.f; + hparams.rope_freq_base_train = 160000.f; + hparams.n_swa = 128; ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa); ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); @@ -7553,8 
+7555,10 @@ struct llm_build_bert : public llm_graph_context { template struct llm_build_modern_bert : public llm_graph_context { llm_build_modern_bert(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { - const int64_t n_embd_head = hparams.n_embd_head_v; - const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); + const int64_t n_embd_head = hparams.n_embd_head_v; + const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); + const float rope_theta_global = hparams.rope_freq_base_train; + const float rope_theta_local = hparams.rope_freq_base_train_swa; GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); @@ -7580,7 +7584,7 @@ struct llm_build_modern_bert : public llm_graph_context { ggml_tensor * Kcur; ggml_tensor * Vcur; - float rope_theta = il % 3 == 0 ? hparams.rope_freq_base_train : hparams.rope_freq_base_train_swa; + const float rope_theta = il % 3 == 0 ? rope_theta_global : rope_theta_local; // attention layer norm if (model.layers[il].attn_norm) {
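            // rope_theta above resolves to rope_theta_global (160000.0) on layers where il % 3 == 0
            // (full-context attention) and to rope_theta_local (10000.0) on the sliding-window layers,
            // matching the values set in load_hparams.
            // attn_norm is checked here because layer 0 uses an Identity() attention pre-norm in the
            // reference model, so the tensor may be absent for that layer.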