From 832e32639f600197ad2bfbb2e5bf9c1479243a49 Mon Sep 17 00:00:00 2001
From: Georgi Gerganov
Date: Fri, 27 Mar 2026 11:29:16 +0200
Subject: [PATCH] cont : rotate V more + refactor

---
 src/llama-graph.cpp | 63 ++++++++++++++++++++++++++++-----------------
 1 file changed, 39 insertions(+), 24 deletions(-)

diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp
index 1b3370b6dc..8dfc92b718 100644
--- a/src/llama-graph.cpp
+++ b/src/llama-graph.cpp
@@ -52,13 +52,13 @@ static bool can_reuse_kq_mask(
 
 // impl
 
-static bool is_power_of_2(int n) {
+static bool ggml_is_power_of_2(int n) {
     return (n & (n - 1)) == 0;
 }
 
 // orthonormal Walsh-Hadamard rotation matrix
 static void set_input_hadamard(int n, float * data) {
-    assert(is_power_of_2(n));
+    assert(ggml_is_power_of_2(n));
 
     data[0*n + 0] = 1.0 / sqrtf(n);
 
@@ -75,6 +75,20 @@ static void set_input_hadamard(int n, float * data) {
     }
 }
 
+static ggml_tensor * ggml_rotate_hadamard(
+        ggml_context * ctx,
+         ggml_tensor * cur,
+         ggml_tensor * rot) {
+    const auto n = rot->ne[0];
+
+    ggml_tensor * res;
+    res = ggml_reshape_2d(ctx, cur, n, ggml_nelements(cur)/n);
+    res = ggml_mul_mat(ctx, rot, res);
+    res = ggml_reshape_4d(ctx, res, cur->ne[0], cur->ne[1], cur->ne[2], cur->ne[3]);
+
+    return res;
+}
+
 void llm_graph_input_embd::set_input(const llama_ubatch * ubatch) {
     if (ubatch->token) {
         const int64_t n_tokens = ubatch->n_tokens;
@@ -2076,17 +2090,23 @@ static std::unique_ptr<llm_graph_input_attn_kv> build_attn_inp_kv_impl(
     }
 
     {
+        // I think we can afford to rotate the V more compared to Q and K
+        // ref: https://github.com/ggml-org/llama.cpp/pull/21038
+
         const bool can_rot =
             !hparams.is_n_embd_k_gqa_variable() &&
             !hparams.is_n_embd_v_gqa_variable() &&
-            is_power_of_2(hparams.n_embd_head_k()) &&
-            is_power_of_2(hparams.n_embd_head_v()) &&
+            ggml_is_power_of_2(hparams.n_embd_head_k()) &&
+            //ggml_is_power_of_2(hparams.n_embd_head_v()) &&
+            hparams.n_embd_head_v() % 64 == 0 &&
+            hparams.n_embd_head_k() >= 64 &&
+            hparams.n_embd_head_v() >= 64 &&
             ggml_is_quantized(mctx_cur->type_k()) &&
             ggml_is_quantized(mctx_cur->type_v());
 
         if (can_rot) {
             const auto nk = hparams.n_embd_head_k();
-            const auto nv = hparams.n_embd_head_v();
+            const auto nv = 64;
 
             inp->self_rotk = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, nk, nk);
             inp->self_rotv = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, nv, nv);
@@ -2125,9 +2145,9 @@ ggml_tensor * llm_graph_context::build_attn(
     GGML_ASSERT(v_mla == nullptr);
 
     if (inp->self_rotk) {
-        q_cur = ggml_mul_mat(ctx0, inp->self_rotk, q_cur);
-        k_cur = ggml_mul_mat(ctx0, inp->self_rotk, k_cur);
-        v_cur = ggml_mul_mat(ctx0, inp->self_rotv, v_cur);
+        q_cur = ggml_rotate_hadamard(ctx0, q_cur, inp->self_rotk);
+        k_cur = ggml_rotate_hadamard(ctx0, k_cur, inp->self_rotk);
+        v_cur = ggml_rotate_hadamard(ctx0, v_cur, inp->self_rotv);
     }
 
     // these nodes are added to the graph together so that they are not reordered
@@ -2158,11 +2178,7 @@ ggml_tensor * llm_graph_context::build_attn(
     cb(cur, "kqv_out", il);
 
     if (inp->self_rotv) {
-        const auto n = inp->self_rotv->ne[0];
-
-        cur = ggml_reshape_4d(ctx0, cur, n, cur->ne[0]/n, cur->ne[1], cur->ne[2]);
-        cur = ggml_mul_mat(ctx0, inp->self_rotv, cur);
-        cur = ggml_reshape_3d(ctx0, cur, cur->ne[0]*cur->ne[1], cur->ne[2], cur->ne[3]);
+        cur = ggml_rotate_hadamard(ctx0, cur, inp->self_rotv);
     }
 
     if (wo) {
@@ -2276,12 +2292,12 @@ ggml_tensor * llm_graph_context::build_attn(
                   float   kq_scale,
                     int   il) const {
     if (inp->self_rotk) {
-        q_cur = ggml_mul_mat(ctx0, inp->self_rotk, q_cur);
+        q_cur = ggml_rotate_hadamard(ctx0, q_cur, inp->self_rotk);
         if (k_cur) {
-            k_cur = ggml_mul_mat(ctx0, inp->self_rotk, k_cur);
+            k_cur = ggml_rotate_hadamard(ctx0, k_cur, inp->self_rotk);
         }
         if (v_cur) {
-            v_cur = ggml_mul_mat(ctx0, inp->self_rotv, v_cur);
+            v_cur = ggml_rotate_hadamard(ctx0, v_cur, inp->self_rotv);
         }
     }
 
@@ -2326,11 +2342,7 @@ ggml_tensor * llm_graph_context::build_attn(
     cb(cur, "kqv_out", il);
 
     if (inp->self_rotv) {
-        const auto n = inp->self_rotv->ne[0];
-
-        cur = ggml_reshape_4d(ctx0, cur, n, cur->ne[0]/n, cur->ne[1], cur->ne[2]);
-        cur = ggml_mul_mat(ctx0, inp->self_rotv, cur);
-        cur = ggml_reshape_3d(ctx0, cur, cur->ne[0]*cur->ne[1], cur->ne[2], cur->ne[3]);
+        cur = ggml_rotate_hadamard(ctx0, cur, inp->self_rotv);
    }
 
     if (wo) {
@@ -2441,14 +2453,17 @@ llm_graph_input_attn_kv_iswa * llm_graph_context::build_attn_inp_kv_iswa() const
     const bool can_rot =
         !hparams.is_n_embd_k_gqa_variable() &&
         !hparams.is_n_embd_v_gqa_variable() &&
-        is_power_of_2(hparams.n_embd_head_k()) &&
-        is_power_of_2(hparams.n_embd_head_v()) &&
+        ggml_is_power_of_2(hparams.n_embd_head_k()) &&
+        //ggml_is_power_of_2(hparams.n_embd_head_v()) &&
+        hparams.n_embd_head_v() % 64 == 0 &&
+        hparams.n_embd_head_k() >= 64 &&
+        hparams.n_embd_head_v() >= 64 &&
         ggml_is_quantized(mctx_cur->get_base()->type_k()) &&
         ggml_is_quantized(mctx_cur->get_base()->type_v());
 
     if (can_rot) {
         const auto nk = hparams.n_embd_head_k();
-        const auto nv = hparams.n_embd_head_v();
+        const auto nv = 64;
 
         inp->self_rotk = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, nk, nk);
         inp->self_rotv = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, nv, nv);
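
For reference, the first hunk shows only the seed assignment of `set_input_hadamard()`; the doubling loop that fills in the rest of the matrix falls outside the diff context. Below is a minimal standalone sketch of the standard Sylvester construction that is consistent with the visible seed `data[0*n + 0] = 1.0 / sqrtf(n)`. This is an assumption, not taken from the patch; `hadamard_orthonormal` and the self-check in `main` are illustrative names:

```cpp
#include <cassert>
#include <cmath>
#include <cstdio>
#include <vector>

// sketch of the Sylvester doubling H_{2m} = [[H_m, H_m], [H_m, -H_m]],
// seeded with 1/sqrt(n) so every entry ends up +-1/sqrt(n) (orthonormal)
static void hadamard_orthonormal(int n, float * data) {
    assert(n > 0 && (n & (n - 1)) == 0); // n must be a power of two

    data[0] = 1.0f/sqrtf((float) n);

    for (int m = 1; m < n; m *= 2) {
        // replicate the current top-left m x m block into the other three
        // quadrants, negating the bottom-right one
        for (int i = 0; i < m; ++i) {
            for (int j = 0; j < m; ++j) {
                const float v = data[i*n + j];
                data[(i    )*n + (j + m)] =  v;
                data[(i + m)*n + (j    )] =  v;
                data[(i + m)*n + (j + m)] = -v;
            }
        }
    }
}

int main() {
    const int n = 8;
    std::vector<float> h((size_t) n*n);
    hadamard_orthonormal(n, h.data());

    // orthonormality check: H*H^T should be the identity
    float err = 0.0f;
    for (int i = 0; i < n; ++i) {
        for (int j = 0; j < n; ++j) {
            float acc = 0.0f;
            for (int k = 0; k < n; ++k) {
                acc += h[i*n + k]*h[j*n + k];
            }
            err = fmaxf(err, fabsf(acc - (i == j ? 1.0f : 0.0f)));
        }
    }
    printf("max |H*H^T - I| = %g\n", err); // expect ~1e-7
}
```

Every entry has magnitude 1/sqrt(n), so each row has unit norm and distinct rows are orthogonal, which is what makes the rotation lossless up to quantization error.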
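Two properties make the relaxed `nv = 64` condition work. First, `ggml_rotate_hadamard()` views any tensor as rows of length `rot->ne[0]`, so a V head whose size is a multiple of 64 (e.g. 192, which is not a power of two) is rotated as independent 64-wide blocks. Second, a Sylvester-Hadamard matrix is symmetric as well as orthonormal (H = H^T = H^-1), so applying the same `self_rotv` to `kqv_out` undoes the rotation: softmax(Q K^T) (V H) H = softmax(Q K^T) V, while Q and K sharing `self_rotk` leaves (Q H)(K H)^T = Q K^T unchanged. A hedged plain-C++ model of the round trip (no ggml; the head size 192 is just an example, and `__builtin_popcount` assumes GCC/Clang):

```cpp
#include <cmath>
#include <cstdio>
#include <vector>

// entry (i, j) of the Sylvester-Walsh-Hadamard matrix is (-1)^popcount(i & j);
// the 1/sqrt(n) scale makes it orthonormal, and it is symmetric by construction
static float hadamard_entry(int i, int j, int n) {
    return ((__builtin_popcount(i & j) & 1) ? -1.0f : 1.0f) / sqrtf((float) n);
}

// model of ggml_rotate_hadamard(): view x as nrows rows of length n and
// multiply each row by the n x n rotation
static void rotate_rows(float * x, int n, int nrows) {
    std::vector<float> tmp(n);
    for (int r = 0; r < nrows; ++r) {
        float * row = x + (size_t) r*n;
        for (int i = 0; i < n; ++i) {
            tmp[i] = 0.0f;
            for (int j = 0; j < n; ++j) {
                tmp[i] += hadamard_entry(i, j, n)*row[j];
            }
        }
        for (int i = 0; i < n; ++i) {
            row[i] = tmp[i];
        }
    }
}

int main() {
    const int n_embd_head_v = 192; // multiple of 64, not a power of two
    const int nv            = 64;  // rotated as 192/64 = 3 independent blocks

    std::vector<float> v(n_embd_head_v);
    for (int i = 0; i < n_embd_head_v; ++i) {
        v[i] = sinf((float) i); // arbitrary test data
    }
    const std::vector<float> orig = v;

    // rotate, then rotate again with the same matrix: H*H == I,
    // so the values must round-trip
    rotate_rows(v.data(), nv, n_embd_head_v/nv);
    rotate_rows(v.data(), nv, n_embd_head_v/nv);

    float err = 0.0f;
    for (int i = 0; i < n_embd_head_v; ++i) {
        err = fmaxf(err, fabsf(v[i] - orig[i]));
    }
    printf("round-trip max abs err = %g\n", err); // expect ~1e-6
}
```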