cont : rotate V more + refactor

Georgi Gerganov
2026-03-27 11:29:16 +02:00
parent e5aa067d68
commit 832e32639f

@@ -52,13 +52,13 @@ static bool can_reuse_kq_mask(
 // impl
-static bool is_power_of_2(int n) {
+static bool ggml_is_power_of_2(int n) {
     return (n & (n - 1)) == 0;
 }
 // orthonormal Walsh-Hadamard rotation matrix
 static void set_input_hadamard(int n, float * data) {
-    assert(is_power_of_2(n));
+    assert(ggml_is_power_of_2(n));
     data[0*n + 0] = 1.0 / sqrtf(n);
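Note (commentary, not part of the diff): set_input_hadamard fills an n x n orthonormal Walsh-Hadamard rotation. Its body is elided here, but the 1/sqrtf(n) seed above is consistent with the usual Sylvester doubling construction. A minimal standalone sketch of that construction follows, with an illustrative name (hadamard_orthonormal) and no claim that it matches the actual implementation:

#include <cassert>
#include <cmath>

// Sketch: fill a row-major n x n buffer with the orthonormal Walsh-Hadamard
// matrix via Sylvester doubling: H_{2m} = [H_m H_m; H_m -H_m]. Seeding the
// recursion with 1/sqrt(n) makes the final matrix satisfy H * H^T = I.
static void hadamard_orthonormal(int n, float * data) {
    assert(n > 0 && (n & (n - 1)) == 0); // n must be a power of 2
    data[0] = 1.0f / std::sqrt((float) n);
    for (int m = 1; m < n; m *= 2) {
        for (int i = 0; i < m; ++i) {
            for (int j = 0; j < m; ++j) {
                const float v = data[i*n + j];
                data[(i    )*n + (j + m)] =  v; // top-right copy
                data[(i + m)*n + (j    )] =  v; // bottom-left copy
                data[(i + m)*n + (j + m)] = -v; // bottom-right negated
            }
        }
    }
}

For n = 4 this produces the sign pattern + + + + / + - + - / + + - - / + - - + scaled by 1/2, i.e. an orthogonal, symmetric matrix.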
@@ -75,6 +75,20 @@ static void set_input_hadamard(int n, float * data) {
     }
 }
+static ggml_tensor * ggml_rotate_hadamard(
+        ggml_context * ctx,
+        ggml_tensor * cur,
+        ggml_tensor * rot) {
+    const auto n = rot->ne[0];
+    ggml_tensor * res;
+    res = ggml_reshape_2d(ctx, cur, n, ggml_nelements(cur)/n);
+    res = ggml_mul_mat(ctx, rot, res);
+    res = ggml_reshape_4d(ctx, res, cur->ne[0], cur->ne[1], cur->ne[2], cur->ne[3]);
+    return res;
+}
 void llm_graph_input_embd::set_input(const llama_ubatch * ubatch) {
     if (ubatch->token) {
         const int64_t n_tokens = ubatch->n_tokens;
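Note (commentary, not part of the diff): ggml_rotate_hadamard applies the n x n rotation along dim 0 by flattening the tensor into columns of n elements, multiplying by rot, and restoring the original 4-D shape, so every contiguous group of n values is rotated independently. This is what lets a fixed 64x64 matrix handle any head size that is a multiple of 64. A rough standalone equivalent on a plain float buffer, with the illustrative name rotate_blocks (ggml_mul_mat contracts over dim 0 of both operands, i.e. it effectively applies rot transposed, but the orthonormal Hadamard matrix is symmetric, so the distinction does not matter):

#include <cstdint>
#include <vector>

// Sketch: rotate every contiguous block of n floats in data (count elements,
// count % n == 0) by the row-major n x n matrix rot.
static void rotate_blocks(const float * rot, float * data, int64_t count, int n) {
    std::vector<float> tmp(n);
    for (int64_t b = 0; b < count; b += n) {
        for (int i = 0; i < n; ++i) {
            float acc = 0.0f;
            for (int j = 0; j < n; ++j) {
                acc += rot[i*n + j] * data[b + j];
            }
            tmp[i] = acc;
        }
        for (int i = 0; i < n; ++i) {
            data[b + i] = tmp[i];
        }
    }
}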
@@ -2076,17 +2090,23 @@ static std::unique_ptr<llm_graph_input_attn_kv> build_attn_inp_kv_impl(
     }
     {
+        // I think we can afford to rotate the V more compared to Q and K
+        // ref: https://github.com/ggml-org/llama.cpp/pull/21038
         const bool can_rot =
             !hparams.is_n_embd_k_gqa_variable() &&
             !hparams.is_n_embd_v_gqa_variable() &&
-            is_power_of_2(hparams.n_embd_head_k()) &&
-            is_power_of_2(hparams.n_embd_head_v()) &&
+            ggml_is_power_of_2(hparams.n_embd_head_k()) &&
+          //ggml_is_power_of_2(hparams.n_embd_head_v()) &&
+            hparams.n_embd_head_v() % 64 == 0 &&
             hparams.n_embd_head_k() >= 64 &&
             hparams.n_embd_head_v() >= 64 &&
             ggml_is_quantized(mctx_cur->type_k()) &&
             ggml_is_quantized(mctx_cur->type_v());
         if (can_rot) {
             const auto nk = hparams.n_embd_head_k();
-            const auto nv = hparams.n_embd_head_v();
+            const auto nv = 64;
             inp->self_rotk = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, nk, nk);
             inp->self_rotv = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, nv, nv);
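Note (commentary, not part of the diff): this is the "rotate V more" part of the change. Q and K keep the full nk x nk rotation, so n_embd_head_k still has to be a power of 2, while V is now rotated with a fixed 64x64 block (nv = 64), so n_embd_head_v only needs to be a multiple of 64. A simplified predicate mirroring the visible conditions, with the illustrative name can_rotate (the GQA-variable and quantized-KV-cache checks are omitted):

// Sketch of the relaxed eligibility check on the head sizes only.
static bool can_rotate(int n_embd_head_k, int n_embd_head_v) {
    const bool k_ok = n_embd_head_k >= 64 && (n_embd_head_k & (n_embd_head_k - 1)) == 0;
    const bool v_ok = n_embd_head_v >= 64 && n_embd_head_v % 64 == 0;
    return k_ok && v_ok;
}

// can_rotate(128, 128) -> true
// can_rotate(128, 192) -> true   (192 is not a power of 2, but 192 % 64 == 0)
// can_rotate( 96, 128) -> false  (96 is not a power of 2)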
@@ -2125,9 +2145,9 @@ ggml_tensor * llm_graph_context::build_attn(
     GGML_ASSERT(v_mla == nullptr);
     if (inp->self_rotk) {
-        q_cur = ggml_mul_mat(ctx0, inp->self_rotk, q_cur);
-        k_cur = ggml_mul_mat(ctx0, inp->self_rotk, k_cur);
-        v_cur = ggml_mul_mat(ctx0, inp->self_rotv, v_cur);
+        q_cur = ggml_rotate_hadamard(ctx0, q_cur, inp->self_rotk);
+        k_cur = ggml_rotate_hadamard(ctx0, k_cur, inp->self_rotk);
+        v_cur = ggml_rotate_hadamard(ctx0, v_cur, inp->self_rotv);
     }
     // these nodes are added to the graph together so that they are not reordered
@@ -2158,11 +2178,7 @@ ggml_tensor * llm_graph_context::build_attn(
     cb(cur, "kqv_out", il);
     if (inp->self_rotv) {
-        const auto n = inp->self_rotv->ne[0];
-        cur = ggml_reshape_4d(ctx0, cur, n, cur->ne[0]/n, cur->ne[1], cur->ne[2]);
-        cur = ggml_mul_mat(ctx0, inp->self_rotv, cur);
-        cur = ggml_reshape_3d(ctx0, cur, cur->ne[0]*cur->ne[1], cur->ne[2], cur->ne[3]);
+        cur = ggml_rotate_hadamard(ctx0, cur, inp->self_rotv);
     }
     if (wo) {
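Note (commentary, not part of the diff): the rotations are safe because the matrices are orthonormal, and the V rotation is inverted after attention simply by applying the same matrix again (the "kqv_out" path above), since the normalized Walsh-Hadamard matrix is symmetric and therefore its own inverse. Presumably the point is to spread outliers before the KV cache is quantized (can_rot requires quantized cache types). Writing H for the rotation, q, k, v_i for per-head vectors, and a_i for attention weights:

\[ H^\top H = I, \qquad H^\top = H \;\Rightarrow\; H^{-1} = H \]
\[ (Hq)^\top (Hk) = q^\top H^\top H\,k = q^\top k \]
\[ \sum_i a_i (H v_i) = H \sum_i a_i v_i \;\Rightarrow\; H\Big(\sum_i a_i (H v_i)\Big) = \sum_i a_i v_i \]

For V the same argument applies per 64-element block.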
@@ -2276,12 +2292,12 @@ ggml_tensor * llm_graph_context::build_attn(
         float kq_scale,
         int il) const {
     if (inp->self_rotk) {
-        q_cur = ggml_mul_mat(ctx0, inp->self_rotk, q_cur);
+        q_cur = ggml_rotate_hadamard(ctx0, q_cur, inp->self_rotk);
         if (k_cur) {
-            k_cur = ggml_mul_mat(ctx0, inp->self_rotk, k_cur);
+            k_cur = ggml_rotate_hadamard(ctx0, k_cur, inp->self_rotk);
         }
         if (v_cur) {
-            v_cur = ggml_mul_mat(ctx0, inp->self_rotv, v_cur);
+            v_cur = ggml_rotate_hadamard(ctx0, v_cur, inp->self_rotv);
         }
     }
@@ -2326,11 +2342,7 @@ ggml_tensor * llm_graph_context::build_attn(
     cb(cur, "kqv_out", il);
     if (inp->self_rotv) {
-        const auto n = inp->self_rotv->ne[0];
-        cur = ggml_reshape_4d(ctx0, cur, n, cur->ne[0]/n, cur->ne[1], cur->ne[2]);
-        cur = ggml_mul_mat(ctx0, inp->self_rotv, cur);
-        cur = ggml_reshape_3d(ctx0, cur, cur->ne[0]*cur->ne[1], cur->ne[2], cur->ne[3]);
+        cur = ggml_rotate_hadamard(ctx0, cur, inp->self_rotv);
     }
     if (wo) {
@@ -2441,14 +2453,17 @@ llm_graph_input_attn_kv_iswa * llm_graph_context::build_attn_inp_kv_iswa() const
     const bool can_rot =
         !hparams.is_n_embd_k_gqa_variable() &&
         !hparams.is_n_embd_v_gqa_variable() &&
-        is_power_of_2(hparams.n_embd_head_k()) &&
-        is_power_of_2(hparams.n_embd_head_v()) &&
+        ggml_is_power_of_2(hparams.n_embd_head_k()) &&
+      //ggml_is_power_of_2(hparams.n_embd_head_v()) &&
+        hparams.n_embd_head_v() % 64 == 0 &&
         hparams.n_embd_head_k() >= 64 &&
         hparams.n_embd_head_v() >= 64 &&
         ggml_is_quantized(mctx_cur->get_base()->type_k()) &&
         ggml_is_quantized(mctx_cur->get_base()->type_v());
     if (can_rot) {
         const auto nk = hparams.n_embd_head_k();
-        const auto nv = hparams.n_embd_head_v();
+        const auto nv = 64;
         inp->self_rotk = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, nk, nk);
         inp->self_rotv = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, nv, nv);