Commit d81ce04

Add support for LLaDA-8b diffusion model

llama: fix llama-model fixup

1 parent b172309

File tree

12 files changed: +972 −43 lines


common/arg.cpp

Lines changed: 44 additions & 19 deletions
@@ -3431,34 +3431,59 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
         }
     ).set_examples({LLAMA_EXAMPLE_SERVER}));

-    // diffusion parameters
+    // shared diffusion parameters
     add_opt(common_arg(
         { "--diffusion-steps" }, "N",
-        string_format("number of diffusion steps (default: %d)", params.diffusion.steps),
-        [](common_params & params, int value) { params.diffusion.steps = value; }
-    ).set_examples({ LLAMA_EXAMPLE_DIFFUSION }));
+        string_format("number of diffusion steps (default: %d)", params.diffusion_dream.steps),
+        [](common_params & params, int value) {
+            params.diffusion_dream.steps = value;
+            params.diffusion_llada.steps = value;
+        }
+    ).set_examples({ LLAMA_EXAMPLE_DIFFUSION_DREAM, LLAMA_EXAMPLE_DIFFUSION_LLADA }));
+    add_opt(common_arg(
+        { "--diffusion-visual" },
+        string_format("enable visual diffusion mode (show progressive generation) (default: %s)",
+                      params.diffusion_dream.visual_mode ? "true" : "false"),
+        [](common_params & params) {
+            params.diffusion_dream.visual_mode = true;
+            params.diffusion_llada.visual_mode = true;
+        }
+    ).set_examples({ LLAMA_EXAMPLE_DIFFUSION_DREAM, LLAMA_EXAMPLE_DIFFUSION_LLADA }));
+
+    // DREAM-specific diffusion parameters
     add_opt(common_arg(
         { "--diffusion-eps" }, "F",
-        string_format("epsilon for timesteps (default: %.6f)", (double) params.diffusion.eps),
-        [](common_params & params, const std::string & value) { params.diffusion.eps = std::stof(value); }
-    ).set_examples({ LLAMA_EXAMPLE_DIFFUSION }));
+        string_format("epsilon for timesteps (default: %.6f)", (double) params.diffusion_dream.eps),
+        [](common_params & params, const std::string & value) { params.diffusion_dream.eps = std::stof(value); }
+    ).set_examples({ LLAMA_EXAMPLE_DIFFUSION_DREAM }));
     add_opt(common_arg(
         { "--diffusion-algorithm" }, "N",
         string_format("diffusion algorithm: 0=ORIGIN, 1=MASKGIT_PLUS, 2=TOPK_MARGIN, 3=ENTROPY (default: %d)",
-                      params.diffusion.algorithm),
-        [](common_params & params, int value) { params.diffusion.algorithm = value; }
-    ).set_examples({ LLAMA_EXAMPLE_DIFFUSION }));
+                      params.diffusion_dream.algorithm),
+        [](common_params & params, int value) { params.diffusion_dream.algorithm = value; }
+    ).set_examples({ LLAMA_EXAMPLE_DIFFUSION_DREAM }));
     add_opt(common_arg(
         { "--diffusion-alg-temp" }, "F",
-        string_format("algorithm temperature (default: %.3f)", (double) params.diffusion.alg_temp),
-        [](common_params & params, const std::string & value) { params.diffusion.alg_temp = std::stof(value); }
-    ).set_examples({ LLAMA_EXAMPLE_DIFFUSION }));
-    add_opt(common_arg(
-        { "--diffusion-visual" },
-        string_format("enable visual diffusion mode (show progressive generation) (default: %s)",
-                      params.diffusion.visual_mode ? "true" : "false"),
-        [](common_params & params) { params.diffusion.visual_mode = true; }
-    ).set_examples({ LLAMA_EXAMPLE_DIFFUSION }));
+        string_format("algorithm temperature (default: %.3f)", (double) params.diffusion_dream.alg_temp),
+        [](common_params & params, const std::string & value) { params.diffusion_dream.alg_temp = std::stof(value); }
+    ).set_examples({ LLAMA_EXAMPLE_DIFFUSION_DREAM }));
+
+    // LLADA-specific diffusion parameters
+    add_opt(common_arg(
+        { "--diffusion-block-length" }, "N",
+        string_format("block length for generation (default: %d)", params.diffusion_llada.block_length),
+        [](common_params & params, int value) { params.diffusion_llada.block_length = value; }
+    ).set_examples({ LLAMA_EXAMPLE_DIFFUSION_LLADA }));
+    add_opt(common_arg(
+        { "--diffusion-cfg-scale" }, "F",
+        string_format("classifier-free guidance scale (default: %.3f)", (double) params.diffusion_llada.cfg_scale),
+        [](common_params & params, const std::string & value) { params.diffusion_llada.cfg_scale = std::stof(value); }
+    ).set_examples({ LLAMA_EXAMPLE_DIFFUSION_LLADA }));
+    add_opt(common_arg(
+        { "--diffusion-remasking-alg" }, "N",
+        string_format("remasking algorithm: 0=LOW_CONFIDENCE, 1=RANDOM (default: %d)", params.diffusion_llada.remasking),
+        [](common_params & params, int value) { params.diffusion_llada.remasking = value; }
+    ).set_examples({ LLAMA_EXAMPLE_DIFFUSION_LLADA }));

     return ctx_arg;
 }
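
The shared options above fan a single parsed value out into both per-model structs, so one --diffusion-steps flag serves either example binary. A minimal standalone sketch of that fan-out pattern (stand-in types, not the actual common_arg machinery):

#include <cstdint>

// Stand-ins for the two param structs this commit adds to common.h.
struct dream_params { int32_t steps = 64; };
struct llada_params { int32_t steps = 64; };

struct all_params {
    dream_params diffusion_dream;
    llada_params diffusion_llada;
};

// A shared flag's handler writes the same value into both structs; each CLI
// later reads only the struct for its own model family.
static void on_diffusion_steps(all_params & params, int32_t value) {
    params.diffusion_dream.steps = value;
    params.diffusion_llada.steps = value;
}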

common/common.h

Lines changed: 14 additions & 4 deletions
@@ -81,7 +81,8 @@ enum llama_example {
     LLAMA_EXAMPLE_LOOKUP,
     LLAMA_EXAMPLE_PARALLEL,
     LLAMA_EXAMPLE_TTS,
-    LLAMA_EXAMPLE_DIFFUSION,
+    LLAMA_EXAMPLE_DIFFUSION_DREAM,
+    LLAMA_EXAMPLE_DIFFUSION_LLADA,

     LLAMA_EXAMPLE_COUNT,
 };
@@ -219,14 +220,22 @@ struct common_params_vocoder {
     bool use_guide_tokens = false; // enable guide tokens to improve TTS accuracy // NOLINT
 };

-struct common_params_diffusion {
+struct common_params_diffusion_dream {
     int32_t steps       = 64;    // number of diffusion steps
     float   eps         = 1e-3f; // epsilon for timesteps
     int32_t algorithm   = 0;     // diffusion algorithm (0=ORIGIN, 1=MASKGIT_PLUS, 2=TOPK_MARGIN, 3=ENTROPY)
     float   alg_temp    = 0.0f;  // algorithm temperature
     bool    visual_mode = false; // show progressive diffusion on screen
 };

+struct common_params_diffusion_llada {
+    int32_t steps        = 64;    // number of diffusion steps
+    int32_t block_length = 32;    // block length for generation
+    float   cfg_scale    = 0.2f;  // classifier-free guidance scale
+    int32_t remasking    = 0;     // remasking algorithm: 0=LOW_CONFIDENCE, 1=RANDOM
+    bool    visual_mode  = false; // show progressive diffusion on screen
+};
+
 enum common_reasoning_format {
     COMMON_REASONING_FORMAT_NONE,
     COMMON_REASONING_FORMAT_DEEPSEEK_LEGACY, // Extract thinking tag contents and return as `message.reasoning_content`, or leave inline in <think> tags in stream mode
@@ -277,8 +286,9 @@ struct common_params {

     struct common_params_sampling    sampling;
     struct common_params_speculative speculative;
-    struct common_params_vocoder     vocoder;
-    struct common_params_diffusion   diffusion;
+    struct common_params_vocoder         vocoder;
+    struct common_params_diffusion_dream diffusion_dream;
+    struct common_params_diffusion_llada diffusion_llada;

     struct common_params_model model;
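
With both structs sitting side by side in common_params, code that handles either model family has to choose one explicitly. A hedged sketch of that choice (the helper is hypothetical; each CLI in this commit simply hard-codes its own struct):

#include <cstdint>
#include <cstring>

// Stand-ins mirroring the structs above, defaults copied from this diff.
struct common_params_diffusion_dream { int32_t steps = 64; };
struct common_params_diffusion_llada { int32_t steps = 64; int32_t block_length = 32; };

struct common_params {
    common_params_diffusion_dream diffusion_dream;
    common_params_diffusion_llada diffusion_llada;
};

// Hypothetical helper (not in this commit): pick the step count based on the
// architecture string read from the model's general.architecture metadata.
static int32_t diffusion_steps_for_arch(const common_params & params, const char * arch) {
    return std::strcmp(arch, "dream") == 0 ? params.diffusion_dream.steps
                                           : params.diffusion_llada.steps;
}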

convert_hf_to_gguf.py

Lines changed: 153 additions & 0 deletions
@@ -2851,6 +2851,159 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
         yield from super().modify_tensors(data_torch, name, bid)


+@ModelBase.register("LLaDAModelLM")
+class LLaDAModel(TextModel):
+    model_arch = gguf.MODEL_ARCH.LLADA
+    undo_permute = True
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        # fix for SmolVLM2, missing `num_attention_heads` in config.json
+        if self.hf_arch == "VLlama3ForCausalLM":
+            self.hparams["num_attention_heads"] = self.hparams.get("num_attention_heads", 32)
+
+    def get_vocab_base(self) -> tuple[list[str], list[int], str]:
+        tokens: list[str] = []
+        toktypes: list[int] = []
+
+        from transformers import AutoTokenizer
+        tokenizer = AutoTokenizer.from_pretrained(self.dir_model, trust_remote_code=True)
+
+        vocab_dict = tokenizer.get_vocab()
+        vocab_size = self.hparams.get("vocab_size", len(vocab_dict))
+        assert max(vocab_dict.values()) < vocab_size
+
+        tokpre = self.get_vocab_base_pre(tokenizer)
+
+        reverse_vocab = {id_: encoded_tok for encoded_tok, id_ in vocab_dict.items()}
+        added_vocab = tokenizer.get_added_vocab()
+
+        for i in range(vocab_size):
+            if i not in reverse_vocab:
+                tokens.append(f"[PAD{i}]")
+                toktypes.append(gguf.TokenType.UNUSED)
+            elif reverse_vocab[i] in added_vocab:
+                tokens.append(reverse_vocab[i])
+                # Check if it's a special token - treat special tokens as CONTROL tokens
+                if hasattr(tokenizer, 'added_tokens_decoder') and i in tokenizer.added_tokens_decoder:
+                    if tokenizer.added_tokens_decoder[i].special:
+                        toktypes.append(gguf.TokenType.CONTROL)
+                    else:
+                        toktypes.append(gguf.TokenType.USER_DEFINED)
+                else:
+                    # Fallback: treat all added vocab as control tokens for special tokens like <|im_start|>
+                    toktypes.append(gguf.TokenType.CONTROL)
+            else:
+                tokens.append(reverse_vocab[i])
+                toktypes.append(gguf.TokenType.NORMAL)
+
+        return tokens, toktypes, tokpre
+
+    def set_vocab(self):
+        try:
+            self._set_vocab_sentencepiece()
+        except FileNotFoundError:
+            try:
+                self._set_vocab_llama_hf()
+            except (FileNotFoundError, TypeError):
+                # Llama 3
+                self._set_vocab_gpt2()
+
+        # Apply to CodeLlama only (and ignore for Llama 3 with a vocab size of 128256)
+        if self.hparams.get("vocab_size", 32000) == 32016:
+            special_vocab = gguf.SpecialVocab(
+                self.dir_model, load_merges=False,
+                special_token_types = ['prefix', 'suffix', 'middle', 'eot']
+            )
+            special_vocab._set_special_token("prefix", 32007)
+            special_vocab._set_special_token("suffix", 32008)
+            special_vocab._set_special_token("middle", 32009)
+            special_vocab._set_special_token("eot", 32010)
+            special_vocab.add_to_gguf(self.gguf_writer)
+
+        tokenizer_config_file = self.dir_model / 'tokenizer_config.json'
+        if tokenizer_config_file.is_file():
+            with open(tokenizer_config_file, "r", encoding="utf-8") as f:
+                tokenizer_config_json = json.load(f)
+                if "add_prefix_space" in tokenizer_config_json:
+                    self.gguf_writer.add_add_space_prefix(tokenizer_config_json["add_prefix_space"])
+
+        # Apply to granite small models only
+        if self.hparams.get("vocab_size", 32000) == 49152:
+            self.gguf_writer.add_add_bos_token(False)
+
+    def set_gguf_parameters(self):
+        super().set_gguf_parameters()
+        self._try_set_pooling_type()
+
+        # Add parameters similar to LlamaModel
+        hparams = self.hparams
+        self.gguf_writer.add_vocab_size(hparams["vocab_size"])
+
+        if (rope_dim := hparams.get("head_dim")) is None:
+            n_heads = hparams.get("num_attention_heads", hparams.get("n_heads"))
+            rope_dim = hparams.get("hidden_size", hparams.get("d_model")) // n_heads
+        self.gguf_writer.add_rope_dimension_count(rope_dim)
+
+        # Set context length for LLaDA
+        context_length = self.hparams.get("max_sequence_length")
+        self.gguf_writer.add_context_length(context_length)
+
+        # Set embedding length (dimension size)
+        embedding_length = self.hparams.get("d_model")
+        self.gguf_writer.add_embedding_length(embedding_length)
+
+        # Set feed forward length (MLP hidden size)
+        feed_forward_length = self.hparams.get("mlp_hidden_size")
+        self.gguf_writer.add_feed_forward_length(feed_forward_length)
+
+        # Set RoPE parameters
+        if "rope_theta" in self.hparams:
+            self.gguf_writer.add_rope_freq_base(self.hparams["rope_theta"])
+
+        # Set RMS norm epsilon
+        if "rms_norm_eps" in self.hparams:
+            self.gguf_writer.add_layer_norm_rms_eps(self.hparams["rms_norm_eps"])
+
+        # LLaDA models use non-causal attention for diffusion, similar to Dream
+        self.gguf_writer.add_causal_attention(False)
+        # Handle RoPE scaling similar to LlamaModel and Dream
+        rope_scaling = self.hparams.get("rope_scaling") or {}
+        if rope_scaling.get("rope_type", rope_scaling.get("type")) == "linear" and "factor" in rope_scaling:
+            self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR)
+            self.gguf_writer.add_rope_scaling_factor(rope_scaling["factor"])
+        elif rope_scaling.get("rope_type", rope_scaling.get("type")) == "yarn" and "factor" in rope_scaling:
+            self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.YARN)
+            self.gguf_writer.add_rope_scaling_factor(rope_scaling["factor"])
+            self.gguf_writer.add_rope_scaling_orig_ctx_len(rope_scaling["original_max_position_embeddings"])
+
+        # Add LLaDA-specific parameters
+        mask_token_id = self.hparams.get("mask_token_id")
+        if mask_token_id is not None:
+            self.gguf_writer.add_mask_token_id(mask_token_id)
+
+    @staticmethod
+    def permute(weights: Tensor, n_head: int, n_head_kv: int | None):
+        if n_head_kv is not None and n_head != n_head_kv:
+            n_head = n_head_kv
+        return (weights.reshape(n_head, 2, weights.shape[0] // n_head // 2, *weights.shape[1:])
+                .swapaxes(1, 2)
+                .reshape(weights.shape))
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        n_head = self.hparams.get("num_attention_heads", self.hparams.get("n_heads"))
+        n_kv_head = self.hparams.get("num_key_value_heads", self.hparams.get("n_kv_heads"))
+
+        if self.undo_permute:
+            if name.endswith(("q_proj.weight", "q_proj.bias")):
+                data_torch = LLaDAModel.permute(data_torch, n_head, n_head)
+            if name.endswith(("k_proj.weight", "k_proj.bias")):
+                data_torch = LLaDAModel.permute(data_torch, n_head, n_kv_head)
+
+        # LLaDA model tensors should be mapped directly since it's the base model
+        yield from super().modify_tensors(data_torch, name, bid)
+
+
 @ModelBase.register("Ernie4_5_ForCausalLM")
 class Ernie4_5Model(TextModel):
     model_arch = gguf.MODEL_ARCH.ERNIE4_5
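
On the C++ side, the mask token written by add_mask_token_id above is read back through the vocab API; the dream CLI diff below uses the same two calls. A short sketch of that consumer side, assuming the usual llama.h entry points:

#include "llama.h"

// Fetch the mask token the converter stored in the GGUF metadata.
// llama_vocab_mask() yields LLAMA_TOKEN_NULL when no mask token is present,
// which a diffusion CLI must treat as fatal before sampling.
static llama_token require_mask_token(const llama_vocab * vocab) {
    llama_token mask_token_id = llama_vocab_mask(vocab);
    GGML_ASSERT(mask_token_id != LLAMA_TOKEN_NULL);
    return mask_token_id;
}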

examples/diffusion/CMakeLists.txt

Lines changed: 8 additions & 2 deletions
@@ -1,5 +1,11 @@
-set(TARGET llama-diffusion-cli)
-add_executable(${TARGET} diffusion-cli.cpp)
+set(TARGET llama-diffusion-dream-cli)
+add_executable(${TARGET} diffusion-dream-cli.cpp)
+install(TARGETS ${TARGET} RUNTIME)
+target_link_libraries(${TARGET} PRIVATE llama common ${CMAKE_THREAD_LIBS_INIT})
+target_compile_features(${TARGET} PRIVATE cxx_std_17)
+
+set(TARGET llama-diffusion-llada-cli)
+add_executable(${TARGET} diffusion-llada-cli.cpp)
 install(TARGETS ${TARGET} RUNTIME)
 target_link_libraries(${TARGET} PRIVATE llama common ${CMAKE_THREAD_LIBS_INIT})
 target_compile_features(${TARGET} PRIVATE cxx_std_17)

examples/diffusion/diffusion-cli.cpp renamed to examples/diffusion/diffusion-dream-cli.cpp

Lines changed: 23 additions & 18 deletions
@@ -332,9 +332,9 @@ static std::string format_input_text(const std::string & prompt, bool use_chat_t
 }

 struct callback_data {
-    const common_params_diffusion * diff_params;
-    const llama_vocab *             vocab;
-    int32_t                         n_input;
+    const common_params_diffusion_dream * diff_params;
+    const llama_vocab *                   vocab;
+    int32_t                               n_input;
 };

 static bool diffusion_step_callback(int32_t step,
@@ -396,13 +396,13 @@ int main(int argc, char ** argv) {

     common_params params;

-    if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_DIFFUSION)) {
+    if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_DIFFUSION_DREAM)) {
         return 1;
     }

     const char * alg_names[] = { "ORIGIN", "MASKGIT_PLUS", "TOPK_MARGIN", "ENTROPY" };
-    const char * alg_name   = (params.diffusion.algorithm >= 0 && params.diffusion.algorithm <= 3) ?
-                                  alg_names[params.diffusion.algorithm] :
+    const char * alg_name   = (params.diffusion_dream.algorithm >= 0 && params.diffusion_dream.algorithm <= 3) ?
+                                  alg_names[params.diffusion_dream.algorithm] :
                                   "UNKNOWN";

     common_init();
@@ -421,6 +421,11 @@ int main(int argc, char ** argv) {
         return 1;
     }

+    // Check if the model architecture is Dream
+    char arch_str[128];
+    GGML_ASSERT(llama_model_meta_val_str(model, "general.architecture", arch_str, 128) >= 0 &&
+                std::string(arch_str) == "dream");
+
     llama_context_params ctx_params = llama_context_default_params();
     ctx_params.n_ctx   = params.n_ctx;
     ctx_params.n_batch = params.n_batch;
@@ -445,7 +450,7 @@ int main(int argc, char ** argv) {
     std::vector<llama_token> input_tokens = common_tokenize(vocab, formatted_prompt,
                                                             /*add special tokens*/ true,
                                                             /*parse special*/ true);
-    int n_input = input_tokens.size();
+    int n_input = input_tokens.size();

     if (n_input >= params.n_ctx) {
         LOG_ERR("error: input too long (%d tokens), max context is %d\n", n_input, params.n_ctx);
@@ -455,28 +460,28 @@ int main(int argc, char ** argv) {
     }

     struct diffusion_params ldiff_params = diffusion_default_params();
-    ldiff_params.steps       = params.diffusion.steps;
-    ldiff_params.eps         = params.diffusion.eps;
+    ldiff_params.steps       = params.diffusion_dream.steps;
+    ldiff_params.eps         = params.diffusion_dream.eps;
     ldiff_params.temperature = params.sampling.temp;
     ldiff_params.top_p       = params.sampling.top_p;
     ldiff_params.top_k       = params.sampling.top_k;
-    ldiff_params.algorithm   = static_cast<enum diffusion_alg>(params.diffusion.algorithm);
-    ldiff_params.alg_temp    = params.diffusion.alg_temp;
+    ldiff_params.algorithm   = static_cast<enum diffusion_alg>(params.diffusion_dream.algorithm);
+    ldiff_params.alg_temp    = params.diffusion_dream.alg_temp;
     ldiff_params.seed        = params.sampling.seed;

     llama_token mask_token_id = llama_vocab_mask(vocab);
     GGML_ASSERT(mask_token_id != LLAMA_TOKEN_NULL);

-    LOG_INF("diffusion_params: - %-25s llama_token = %d\n", "mask_token_id", mask_token_id);
-    LOG_INF("diffusion_params: - %-25s u32 = %d\n", "steps", params.diffusion.steps);
-    LOG_INF("diffusion_params: - %-25s f32 = %.6f\n", "eps", params.diffusion.eps);
-    LOG_INF("diffusion_params: - %-25s u32 = %d (%s)\n", "algorithm", params.diffusion.algorithm,
+    LOG_INF("dream_diffusion_params: - %-25s llama_token = %d\n", "mask_token_id", mask_token_id);
+    LOG_INF("dream_diffusion_params: - %-25s u32 = %d\n", "steps", params.diffusion_dream.steps);
+    LOG_INF("dream_diffusion_params: - %-25s f32 = %.6f\n", "eps", params.diffusion_dream.eps);
+    LOG_INF("dream_diffusion_params: - %-25s u32 = %d (%s)\n", "algorithm", params.diffusion_dream.algorithm,
             alg_name);
-    LOG_INF("diffusion_params: - %-25s f32 = %.3f\n", "alg_temp", params.diffusion.alg_temp);
+    LOG_INF("dream_diffusion_params: - %-25s f32 = %.3f\n", "alg_temp", params.diffusion_dream.alg_temp);

     ldiff_params.mask_token_id = mask_token_id;

-    callback_data cb_data = { &params.diffusion, vocab, n_input };
+    callback_data cb_data = { &params.diffusion_dream, vocab, n_input };

     ldiff_params.step_callback           = diffusion_step_callback;
     ldiff_params.step_callback_user_data = &cb_data;
@@ -488,7 +493,7 @@ int main(int argc, char ** argv) {
                                                   ldiff_params, n_generated);

     if (n_generated > 0) {
-        if (params.diffusion.visual_mode) {
+        if (params.diffusion_dream.visual_mode) {
             //clear screen and move cursor to top-left
             LOG_INF("\033[2J\033[H");
         }
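
The callback registered via step_callback above is what drives the --diffusion-visual progressive display. Its full parameter list is truncated in the hunk, so the sketch below assumes a plausible shape (step index, total steps, user data); treat the signature as an assumption rather than the commit's actual API:

#include <cstdint>
#include <cstdio>

// Hedged sketch of a per-step progress callback; the (total_steps, user_data)
// parameters are assumed, since the real signature is cut off above.
static bool sketch_step_callback(int32_t step, int32_t total_steps, void * user_data) {
    (void) user_data; // would receive the callback_data set via step_callback_user_data
    std::printf("\rdiffusion step %d/%d", step, total_steps);
    std::fflush(stdout);
    return true; // assumed convention: returning false aborts generation
}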
