src/infer.cpp (150 changes: 115 additions & 35 deletions)
@@ -645,17 +645,47 @@ inline float clip(float x, float v) {
return x < -v ? -v : (x > v ? v : x);
}

static void rope(float* buf, float* vec, int d, int head_dim, int pos, float theta) {
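// YaRN-style RoPE scaling helpers (rope_scaling in the model config):
// - yarn_find_correction_dim: inverse of the RoPE wavelength formula; returns the
//   (fractional) dimension-pair index whose frequency completes `num_rotations`
//   full turns over `max_pos` positions.
// - yarn_linear_ramp: rises from 0 at i == min to 1 at i == max, clamped to [0, 1].
// - yarn_get_mscale: attention-magnitude correction; 1.0 when scale <= 1, otherwise
//   growing logarithmically with the scaling factor.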
inline float yarn_find_correction_dim(int num_rotations, int dim, int base, int max_pos) {
return (dim * logf(max_pos / (num_rotations * 2 * M_PI))) / (2 * logf(base));
}

inline float yarn_linear_ramp(float min, float max, int i) {
if (min == max) {
max += 0.001; // Prevent singularity
}
float lin = (i - min) / (max - min);
return std::clamp(lin, 0.0f, 1.0f);
}

inline float yarn_get_mscale(float scale, float mscale) {
return scale <= 1.0f ? 1.0f : 0.1 * mscale * logf(scale) + 1.0f;
}

static void rope(
float* buf, float* vec,
int d, int pos,
float theta, float scaling_factor,
int beta_fast, int beta_slow,
float mscale, float mscale_all_dim,
int original_max_pos
) {
// For some reason, DeepSeek-V2 was trained using rope output
// layout transposed compared to the input. This means we need a buffer
// to hold intermediate results.
assert(d % 2 == 0);
for (int i = 0; i < d; i += 2) {
int j_head = i % head_dim;
float freq = 1.0f / powf(theta, (float)j_head / (float)head_dim);
float val = pos * freq;
float fcr = cosf(val);
float fci = sinf(val);
// TODO: cache these values
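// YaRN frequency blend: dimension pairs below `low` rotate fast enough to keep their
// original (extrapolated) frequency, pairs above `high` use the position-interpolated
// frequency (freq_extra / scaling_factor), and the ramp mixes the two in between.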
float freq_extra = 1.0f / powf(theta, (float)i / (float)d);
float freq_inter = freq_extra / scaling_factor;
float low = std::max(0.0f, floorf(yarn_find_correction_dim(beta_fast, d, theta, original_max_pos)));
float high = std::min(d - 1.0f, ceilf(yarn_find_correction_dim(beta_slow, d, theta, original_max_pos)));
float inv_freq_mask = 1.0f - yarn_linear_ramp(low, high, i / 2);
float inv_freq = freq_inter * (1.0f - inv_freq_mask) + freq_extra * inv_freq_mask;
float val = pos * inv_freq;
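// Magnitude correction applied to cos/sin: ratio of the attention mscale to the all-dim mscale.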
float _mscale =
yarn_get_mscale(scaling_factor, mscale) / yarn_get_mscale(scaling_factor, mscale_all_dim);
float fcr = _mscale * cosf(val);
float fci = _mscale * sinf(val);

float v0 = vec[i];
float v1 = vec[i + 1];
@@ -667,12 +697,9 @@ static void rope(float* buf, float* vec, int d, int head_dim, int pos, float theta) {
}
}

static void rope_v3(float* vec, int d, int head_dim, int pos, float theta) {
int rotary_dim = head_dim;

static void rope_v3(float* vec, int d, int pos, float theta) {
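// Plain RoPE (no YaRN scaling) over the full rotary dimension d; used for V3-style models.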
for (int i = 0; i < d; i += 2) {
int j_head = i % head_dim;
float freq = j_head >= rotary_dim ? 0.f : 1.0f / powf(theta, (float)j_head / (float)rotary_dim);
float freq = 1.0f / powf(theta, (float)i / (float)d);
float val = pos * freq;
float fcr = cosf(val);
float fci = sinf(val);
@@ -684,17 +711,31 @@ static void rope_v3(float* vec, int d, int head_dim, int pos, float theta) {
}
}

static void rope(float* buf, f16_t* vec, int d, int head_dim, int pos, float theta) {
static void rope(
float* buf, f16_t* vec,
int d, int pos,
float theta, float scaling_factor,
int beta_fast, int beta_slow,
float mscale, float mscale_all_dim,
int original_max_pos
) {
// For some reason, DeepSeek-V2 was trained using rope output
// layout transposed compared to the input. This means we need a buffer
// to hold intermediate results.
assert(d % 2 == 0);
for (int i = 0; i < d; i += 2) {
int j_head = i % head_dim;
float freq = 1.0f / powf(theta, (float)j_head / (float)head_dim);
float val = pos * freq;
float fcr = cosf(val);
float fci = sinf(val);
// TODO: cache these values
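// Same YaRN frequency blend and magnitude correction as the float* overload above.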
float freq_extra = 1.0f / powf(theta, (float)i / (float)d);
float freq_inter = freq_extra / scaling_factor;
float low = std::max(0.0f, floorf(yarn_find_correction_dim(beta_fast, d, theta, original_max_pos)));
float high = std::min(d - 1.0f, ceilf(yarn_find_correction_dim(beta_slow, d, theta, original_max_pos)));
float inv_freq_mask = 1.0f - yarn_linear_ramp(low, high, i / 2);
float inv_freq = freq_inter * (1.0f - inv_freq_mask) + freq_extra * inv_freq_mask;
float val = pos * inv_freq;
float _mscale =
yarn_get_mscale(scaling_factor, mscale) / yarn_get_mscale(scaling_factor, mscale_all_dim);
float fcr = _mscale * cosf(val);
float fci = _mscale * sinf(val);

float v0 = half_to_float(vec[i]);
float v1 = half_to_float(vec[i + 1]);
@@ -706,12 +747,9 @@ static void rope(float* buf, f16_t* vec, int d, int head_dim, int pos, float theta) {
}
}

static void rope_v3(f16_t* vec, int d, int head_dim, int pos, float theta) {
int rotary_dim = head_dim;

static void rope_v3(f16_t* vec, int d, int pos, float theta) {
for (int i = 0; i < d; i += 2) {
int j_head = i % head_dim;
float freq = j_head >= rotary_dim ? 0.f : 1.0f / powf(theta, (float)j_head / (float)rotary_dim);
float freq = 1.0f / powf(theta, (float)i / (float)d);
float val = pos * freq;
float fcr = cosf(val);
float fci = sinf(val);
@@ -958,17 +996,31 @@ void BlockMHA::_attention_impl(
bool is_v3 = c.has_moegate_bias;
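// V3-style models use plain RoPE; otherwise apply YaRN-scaled RoPE. The c.rs_* fields
// are the rope_scaling_* hyperparameters loaded from model metadata in Config::from_yalm.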
for (int h = 0; h < c.n_heads; h++) {
if (is_v3) {
rope_v3(s.q(h) + q_pe_offset, c.qk_rope_head_dim, c.qk_rope_head_dim, pos, c.rope_theta);
rope_v3(s.q(h) + q_pe_offset, c.qk_rope_head_dim, pos, c.rope_theta);
} else {
rope(s.ropebuf(), s.q(h) + q_pe_offset, c.qk_rope_head_dim, c.qk_rope_head_dim, pos, c.rope_theta);
rope(
s.ropebuf(), s.q(h) + q_pe_offset,
c.qk_rope_head_dim, pos,
c.rope_theta, c.rs_factor,
c.rs_beta_fast, c.rs_beta_slow,
c.rs_mscale, c.rs_mscale_all_dim,
c.rs_original_max_pos
);
}
}
int kv_pe_offset = c.kv_lora_rank;
float* k_rope = s.kv_a() + kv_pe_offset;
if (is_v3) {
rope_v3(k_rope, c.qk_rope_head_dim, c.qk_rope_head_dim, pos, c.rope_theta);
rope_v3(k_rope, c.qk_rope_head_dim, pos, c.rope_theta);
} else {
rope(s.ropebuf(), k_rope, c.qk_rope_head_dim, c.qk_rope_head_dim, pos, c.rope_theta);
rope(
s.ropebuf(), k_rope,
c.qk_rope_head_dim, pos,
c.rope_theta, c.rs_factor,
c.rs_beta_fast, c.rs_beta_slow,
c.rs_mscale, c.rs_mscale_all_dim,
c.rs_original_max_pos
);
}
// rms norm to non-pe chunk of kv_a
rmsnorm(s.kv_a(), s.kv_a(), this->rms_kv_a_weight(), c.kv_lora_rank, c.norm_eps);
@@ -1012,9 +1064,16 @@ void BlockMHA::_attention_impl(
for (int h = 0; h < c.n_heads; h++) {
f16_t* kh = key + h * c.head_dim;
if (is_v3) {
rope_v3(kh + q_pe_offset, c.qk_rope_head_dim, c.qk_rope_head_dim, 1, c.rope_theta);
rope_v3(kh + q_pe_offset, c.qk_rope_head_dim, 1, c.rope_theta);
} else {
rope(s.ropebuf(), kh + q_pe_offset, c.qk_rope_head_dim, c.qk_rope_head_dim, 1, c.rope_theta);
rope(
s.ropebuf(), kh + q_pe_offset,
c.qk_rope_head_dim, 1,
c.rope_theta, c.rs_factor,
c.rs_beta_fast, c.rs_beta_slow,
c.rs_mscale, c.rs_mscale_all_dim,
c.rs_original_max_pos
);
}
}
}
@@ -1073,17 +1132,31 @@ void BlockMLA::_attention_impl(
bool is_v3 = c.has_moegate_bias;
for (int h = 0; h < c.n_heads; h++) {
if (is_v3) {
rope_v3(s.q_rope(h), c.qk_rope_head_dim, c.qk_rope_head_dim, pos, c.rope_theta);
rope_v3(s.q_rope(h), c.qk_rope_head_dim, pos, c.rope_theta);
} else {
rope(s.ropebuf(), s.q_rope(h), c.qk_rope_head_dim, c.qk_rope_head_dim, pos, c.rope_theta);
rope(
s.ropebuf(), s.q_rope(h),
c.qk_rope_head_dim, pos,
c.rope_theta, c.rs_factor,
c.rs_beta_fast, c.rs_beta_slow,
c.rs_mscale, c.rs_mscale_all_dim,
c.rs_original_max_pos
);
}
}
int kv_pe_offset = c.kv_lora_rank;
float* k_rope = s.kv_a() + kv_pe_offset;
if (is_v3) {
rope_v3(k_rope, c.qk_rope_head_dim, c.qk_rope_head_dim, pos, c.rope_theta);
rope_v3(k_rope, c.qk_rope_head_dim, pos, c.rope_theta);
} else {
rope(s.ropebuf(), k_rope, c.qk_rope_head_dim, c.qk_rope_head_dim, pos, c.rope_theta);
rope(
s.ropebuf(), k_rope,
c.qk_rope_head_dim, pos,
c.rope_theta, c.rs_factor,
c.rs_beta_fast, c.rs_beta_slow,
c.rs_mscale, c.rs_mscale_all_dim,
c.rs_original_max_pos
);
}
// rms norm to non-pe chunk of kv_a (compressed latent kv)
rmsnorm(s.kv_a(), s.kv_a(), this->rms_kv_a_weight(), c.kv_lora_rank, c.norm_eps);
@@ -1103,9 +1176,16 @@ void BlockMLA::_attention_impl(
for (int r = 0; r < kv_sink; r++) {
f16_t* kv = this->kv_rope_cache(r);
if (is_v3) {
rope_v3(kv, c.qk_rope_head_dim, c.qk_rope_head_dim, 1, c.rope_theta);
rope_v3(kv, c.qk_rope_head_dim, 1, c.rope_theta);
} else {
rope(s.ropebuf(), kv, c.qk_rope_head_dim, c.qk_rope_head_dim, 1, c.rope_theta);
rope(
s.ropebuf(), kv,
c.qk_rope_head_dim, 1,
c.rope_theta, c.rs_factor,
c.rs_beta_fast, c.rs_beta_slow,
c.rs_mscale, c.rs_mscale_all_dim,
c.rs_original_max_pos
);
}
}

@@ -1271,7 +1351,7 @@ void Model::_forward_cpu(InferenceState& s, int token, int pos, InferenceMode mo
// When decoding past the context length, keep the first few tokens in the KV cache
// untouched as "attention sinks" while replacing the rest in ring order.
// See StreamingLLM (https://arxiv.org/pdf/2309.17453) for more.
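// Illustrative example (hypothetical values): with original_max_position = 4096 and
// KV_SINKS = 2, a token at pos = 5000 gives kv_sink = 2, kv_pos = 2 + 4998 % 4094 = 906,
// and kv_len = 4096.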
int original_max_position = c.rs_original_max_position_embeddings;
int original_max_position = c.rs_original_max_pos;
int kv_sink = pos >= original_max_position ? KV_SINKS : 0;
int kv_pos = kv_sink + (pos - kv_sink) % (original_max_position - kv_sink);
int kv_len = pos >= original_max_position ? original_max_position : pos + 1;
src/model.cpp (2 changes: 1 addition & 1 deletion)
@@ -123,7 +123,7 @@ void Config::from_yalm(YALMData& yalm, int context) {
rs_factor = std::stof(yalm.metadata.at("rope_scaling_factor").get<std::string>());
rs_mscale = std::stof(yalm.metadata.at("rope_scaling_mscale").get<std::string>());
rs_mscale_all_dim = std::stof(yalm.metadata.at("rope_scaling_mscale_all_dim").get<std::string>());
rs_original_max_position_embeddings = std::stoi(yalm.metadata.at("rope_scaling_original_max_position_embeddings").get<std::string>());
rs_original_max_pos = std::stoi(yalm.metadata.at("rope_scaling_original_max_position_embeddings").get<std::string>());
}

std::optional<QTensor> check_tensor(const Tensor* tensor, Quant weight_quant, std::array<int, 4> shape, const int debug_line) {
src/model.h (2 changes: 1 addition & 1 deletion)
@@ -89,7 +89,7 @@ struct Config {
float rs_factor;
float rs_mscale;
float rs_mscale_all_dim;
int rs_original_max_position_embeddings;
int rs_original_max_pos;

// If nonzero `context` is supplied, max sequence length is limited to `context`.
void from_yalm(YALMData& yalm, int context = 0);