From 211e58178a7d5a5d0ec38ad2865a4e055ff99698 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sat, 25 Apr 2026 18:27:15 +0300 Subject: [PATCH] wip --- include/llama.h | 29 --------- src/llama-context.cpp | 113 ++++----------------------------- src/llama-context.h | 3 +- src/llama-cparams.h | 4 +- src/llama-ext.h | 10 +++ src/llama-graph.cpp | 14 +++- src/llama-graph.h | 14 ++-- src/llama-hparams.cpp | 2 +- src/llama-hparams.h | 4 +- src/llama-model.cpp | 18 ++++-- src/models/eagle3.cpp | 9 ++- src/models/llama.cpp | 12 +--- src/models/openai-moe-iswa.cpp | 13 +--- src/models/qwen3.cpp | 13 +--- src/models/qwen3moe.cpp | 13 +--- 15 files changed, 78 insertions(+), 193 deletions(-) diff --git a/include/llama.h b/include/llama.h index a630da73b8..eb86981409 100644 --- a/include/llama.h +++ b/include/llama.h @@ -375,10 +375,6 @@ extern "C" { // try to disable when n_seq_max > 1 for improved performance when the sequences do not share a large prefix // ref: https://github.com/ggml-org/llama.cpp/pull/14363 - // EAGLE3 extraction configuration - const struct llama_model * target_model; // reference to target model - // only used to share embedding layer with eagle3 model - // [EXPERIMENTAL] // backend sampler chain configuration (make sure the caller keeps the sampler chains alive) // note: the samplers must be sampler chains (i.e. 
use llama_sampler_chain_init) @@ -690,14 +686,6 @@ extern "C" { int32_t il_start, int32_t il_end); - // - // eagle3 (tmp) - // - - LLAMA_API void llama_set_eagle3( - struct llama_context * ctx, - const struct llama_model * model); - // // Memory // @@ -897,23 +885,6 @@ extern "C" { llama_seq_id dest_seq_id, llama_state_seq_flags flags); - // - // EAGLE3 draft model support - // - - // Get pointer to target model features extracted for EAGLE3 encoder - // Returns NULL if no features are available - // Format: [3*n_embd, n_tokens] - use model.hparams.n_embd and batch.n_tokens for dimensions - LLAMA_API const float * llama_get_eagle3_target_features(struct llama_context * ctx); - - // Set g_embeddings from EAGLE3 encoder output for decoder input - // g_embd: pointer to encoder output embeddings - LLAMA_API void llama_set_eagle3_g_embeddings( - struct llama_context * ctx, - const float * g_embd, - int32_t n_embd, - int32_t n_tokens); - // // Decoding // diff --git a/src/llama-context.cpp b/src/llama-context.cpp index d364d927d9..b94318eee9 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -65,6 +65,8 @@ llama_context::llama_context( cparams.cb_eval = params.cb_eval; cparams.cb_eval_user_data = params.cb_eval_user_data; + cparams.output_layer_inp.resize(hparams.n_layer, false); + // Initialize backend samplers here so they are part of the sampling graph // before the reserve passes run later in this function. This avoids a later // re-reserve when graph nodes change. 
@@ -165,8 +167,6 @@ llama_context::llama_context( cparams.op_offload = params.op_offload; cparams.kv_unified = params.kv_unified; - cparams.eagle3_extract_enabled = false; - // initialized later cparams.pipeline_parallel = false; @@ -1170,30 +1170,14 @@ bool llama_context::set_adapter_cvec( return res; } -void llama_context::set_eagle3(const llama_model * model) { - // Initialize EAGLE3 feature extraction configuration - cparams.eagle3_extract_enabled = !!model; - if (!cparams.eagle3_extract_enabled) { - return; - } +void llama_context::set_output_layer_inp(uint32_t layer_id, bool enable) { + LLAMA_LOG_DEBUG("%s: layer_id = %u, enable = %d\n", __func__, layer_id, enable); // layer_id is uint32_t, so it is printed with %u + + GGML_ASSERT(layer_id < model.hparams.n_layer); + + cparams.output_layer_inp[layer_id] = enable; sched_need_reserve = true; - - const auto & eagle3_hparams = model->hparams; - - // Copy feature extraction layer indices from EAGLE3 model's hparams - eagle3.extract_layer_indices.assign( - eagle3_hparams.eagle3_extract_layers.begin(), - eagle3_hparams.eagle3_extract_layers.end() - ); - - // Allocate tensors array for extraction - eagle3.extract_tensors.resize(eagle3.extract_layer_indices.size(), nullptr); - - LLAMA_LOG_INFO("%s: EAGLE3 extraction enabled for layers [%d, %d, %d]\n", __func__, - eagle3.extract_layer_indices[0], - eagle3.extract_layer_indices[1], - eagle3.extract_layer_indices[2]); } llm_graph_result * llama_context::process_ubatch(const llama_ubatch & ubatch, llm_graph_type gtype, llama_memory_context_i * mctx, ggml_status & ret) { @@ -1271,11 +1255,6 @@ llm_graph_result * llama_context::process_ubatch(const llama_ubatch & ubatch, ll return nullptr; } - // EAGLE3: Extract intermediate layer features after graph execution - if (cparams.eagle3_extract_enabled && !eagle3.extract_tensors.empty()) { - extract_eagle3_features(ubatch); - } - ret = GGML_STATUS_SUCCESS; return res; @@ -1291,8 +1270,7 @@ int llama_context::encode(const llama_batch & batch_inp) { const auto & hparams = 
model.hparams; - // EAGLE3: use 3*target_hidden_size for concatenated features input - const int64_t n_embd = (model.arch == LLM_ARCH_EAGLE3 && batch_inp.embd) ? 3 * hparams.eagle3_target_hidden_size : hparams.n_embd; + const int64_t n_embd = hparams.n_embd_inp(); const int64_t n_vocab = model.vocab.n_tokens(); // note: during encode, we always pass the full sequence starting from pos = 0 @@ -1987,7 +1965,6 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) { has_embd = true; } - size_t backend_float_count = 0; size_t backend_token_count = 0; @@ -2299,27 +2276,6 @@ llm_graph_cb llama_context::graph_get_cb() const { ggml_set_name(cur, name); } - // EAGLE3: Extract intermediate layer features if this is an extraction point - if (cparams.eagle3_extract_enabled) { - static constexpr const char * prefix = "eagle3_extract_"; - static constexpr size_t prefix_len = 15; // strlen("eagle3_extract_") - - if (strncmp(name, prefix, prefix_len) == 0) { - // Parse the extraction index from the name (e.g., "eagle3_extract_0" -> 0) - size_t extract_idx = 0; - if (sscanf(name + prefix_len, "%zu", &extract_idx) == 1 && extract_idx < eagle3.extract_tensors.size()) { - // Mark as output tensor to ensure proper backend assignment - ggml_set_output(cur); - // Store this tensor reference for post-execution extraction - eagle3.extract_tensors[extract_idx] = cur; - LLAMA_LOG_DEBUG("%s: EAGLE3 stored tensor reference for extraction: " - "index=%zu, layer=%d, target_layer=%d, tensor=%s\n", - __func__, extract_idx, il, - eagle3.extract_layer_indices[extract_idx], name); - } - } - } - // norm may be automatically assigned to the backend of the previous layer, increasing data transfer between backends // FIXME: fix in ggml_backend_sched const bool full_offload = model.n_gpu_layers() > model.hparams.n_layer; @@ -3078,7 +3034,6 @@ llama_context_params llama_context_default_params() { /*.op_offload =*/ true, /*.swa_full =*/ true, /*.kv_unified =*/ false, - /*.target_model =*/ nullptr, 
/*.sampler =*/ nullptr, /*.n_sampler =*/ 0, }; @@ -3094,12 +3049,6 @@ llama_context * llama_init_from_model( return nullptr; } - // Auto-setup for EAGLE3: set target embedding if target_model is provided - if (model->arch == LLM_ARCH_EAGLE3 && params.target_model) { - model->target_tok_embd = params.target_model->tok_embd; - LLAMA_LOG_INFO("%s: EAGLE3 auto-setup: using target model's embedding layer\n", __func__); - } - if (params.n_batch == 0 && params.n_ubatch == 0) { LLAMA_LOG_ERROR("%s: n_batch and n_ubatch cannot both be zero\n", __func__); return nullptr; @@ -3381,16 +3330,6 @@ int32_t llama_set_adapter_cvec( return res ? 0 : -1; } -// -// eagle3 (tmp) -// - -void llama_set_eagle3( - llama_context * ctx, - const llama_model * model) { - ctx->set_eagle3(model); -} - // // memory // @@ -3703,36 +3642,6 @@ void llama_opt_epoch( callback_eval); } -// -// EAGLE3 member functions -// - -const float * llama_context::get_eagle3_target_features() const { - GGML_ASSERT(!eagle3.target_features.empty() && "EAGLE3 target features not extracted - call llama_encode() on target model first"); - return eagle3.target_features.data(); -} - -void llama_context::set_eagle3_g_embeddings(const float * g_embd, int32_t n_embd, int32_t n_tokens) { - GGML_ASSERT(g_embd != nullptr && "g_embeddings cannot be null"); - GGML_ASSERT(n_embd > 0 && n_tokens > 0 && "invalid dimensions"); - - const size_t size = n_embd * n_tokens; - eagle3.g_embeddings.resize(size); - std::memcpy(eagle3.g_embeddings.data(), g_embd, size * sizeof(float)); -} - -// -// C API wrappers -// - -const float * llama_get_eagle3_target_features(llama_context * ctx) { - return ctx->get_eagle3_target_features(); -} - -void llama_set_eagle3_g_embeddings(llama_context * ctx, const float * g_embd, int32_t n_embd, int32_t n_tokens) { - ctx->set_eagle3_g_embeddings(g_embd, n_embd, n_tokens); -} - // // ext // @@ -3740,3 +3649,7 @@ void llama_set_eagle3_g_embeddings(llama_context * ctx, const float * g_embd, in 
llama_memory_breakdown llama_get_memory_breakdown(const struct llama_context * ctx) { return ctx->memory_breakdown(); } + +void llama_set_output_layer_inp(struct llama_context * ctx, uint32_t layer_id, bool enable) { + ctx->set_output_layer_inp(layer_id, enable); +} diff --git a/src/llama-context.h b/src/llama-context.h index 7959c7709a..889c344642 100644 --- a/src/llama-context.h +++ b/src/llama-context.h @@ -106,8 +106,7 @@ struct llama_context { int32_t il_start, int32_t il_end); - // TODO: tmp - void set_eagle3(const llama_model * model); + void set_output_layer_inp(uint32_t layer_id, bool enable); // process a single ubatch with a specific graph type // if memory_context is provided, it will be applied first to the context's memory diff --git a/src/llama-cparams.h b/src/llama-cparams.h index 48ab113bac..0bef318be4 100644 --- a/src/llama-cparams.h +++ b/src/llama-cparams.h @@ -3,6 +3,7 @@ #include "llama.h" #include +#include #define LLAMA_MAX_SEQ 256 @@ -38,9 +39,10 @@ struct llama_cparams { bool warmup; bool op_offload; bool kv_unified; - bool eagle3_extract_enabled; // enable layer extraction for EAGLE3 speculative decoding bool pipeline_parallel; + std::vector output_layer_inp; + enum llama_pooling_type pooling_type; ggml_backend_sched_eval_callback cb_eval; diff --git a/src/llama-ext.h b/src/llama-ext.h index 8ce29d217c..cbdd69a1a0 100644 --- a/src/llama-ext.h +++ b/src/llama-ext.h @@ -88,3 +88,13 @@ LLAMA_API int32_t llama_model_n_devices(const struct llama_model * model); LLAMA_API ggml_backend_dev_t llama_model_get_device(const struct llama_model * model, int i); LLAMA_API llama_memory_breakdown llama_get_memory_breakdown(const struct llama_context * ctx); + +// +// model/context data extraction +// + +LLAMA_API void llama_set_output_layer_inp(struct llama_context * ctx, uint32_t layer_id, bool enable); + +LLAMA_API ggml_tensor * llama_model_get_tok_embd(const struct llama_model * model); +LLAMA_API void llama_model_set_tok_embd(struct llama_model * 
model, ggml_tensor * tensor); + diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index 1ff22fb9b2..0b2759ca62 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -805,6 +805,10 @@ void llm_graph_result::reset() { t_logits = nullptr; t_embd = nullptr; t_embd_pooled = nullptr; + + t_layer_inp.resize(LLAMA_MAX_LAYERS); + std::fill(t_layer_inp.begin(), t_layer_inp.end(), nullptr); + t_sampled.clear(); t_sampled_probs.clear(); t_sampled_logits.clear(); @@ -833,7 +837,7 @@ void llm_graph_result::set_inputs(const llama_ubatch * ubatch) { } } -void llm_graph_result::set_outputs() { +void llm_graph_result::set_outputs(const llm_graph_params & params) { if (t_logits != nullptr) { ggml_set_output(t_logits); } @@ -843,6 +847,14 @@ void llm_graph_result::set_outputs() { if (t_embd_pooled != nullptr) { ggml_set_output(t_embd_pooled); } + { // flag the requested per-layer inputs as graph outputs; layers whose input was never recorded by the graph builder stay nullptr and are skipped + const auto & output_layer_inp = params.cparams.output_layer_inp; + for (size_t il = 0; il < output_layer_inp.size(); ++il) { + if (output_layer_inp[il] && il < t_layer_inp.size() && t_layer_inp[il] != nullptr) { + ggml_set_output(t_layer_inp[il]); + } + } + } for (auto & [seq_id, t] : t_sampled) { if (t != nullptr) { ggml_set_output(t); diff --git a/src/llama-graph.h b/src/llama-graph.h index b56077e9c5..18fd2c93ca 100644 --- a/src/llama-graph.h +++ b/src/llama-graph.h @@ -670,6 +670,8 @@ public: ggml_tensor * get_embd() const { return t_embd; } ggml_tensor * get_embd_pooled() const { return t_embd_pooled; } + ggml_tensor * get_layer_inp(int il) const { return t_layer_inp[il]; } + ggml_cgraph * get_gf() const { return gf; } ggml_context * get_ctx() const { return ctx_compute.get(); } @@ -678,7 +680,7 @@ public: void reset(); void set_inputs(const llama_ubatch * ubatch); - void set_outputs(); + void set_outputs(const llm_graph_params & params); // try to update the existing graph result using the new graph parameters in order to reuse it // this can only be done if we determine that the resulting graph using the new graph parameters @@ -698,10 +700,12 @@ public: ggml_tensor * t_embd 
= nullptr; ggml_tensor * t_embd_pooled = nullptr; - std::map t_sampled_logits; - std::map t_candidates; - std::map t_sampled; - std::map t_sampled_probs; + std::vector t_layer_inp; + + std::map t_sampled_logits; + std::map t_candidates; + std::map t_sampled; + std::map t_sampled_probs; std::vector inputs; diff --git a/src/llama-hparams.cpp b/src/llama-hparams.cpp index 002d15d415..4da339764a 100644 --- a/src/llama-hparams.cpp +++ b/src/llama-hparams.cpp @@ -71,7 +71,7 @@ uint32_t llama_hparams::n_rot(uint32_t il) const { } uint32_t llama_hparams::n_embd_inp() const { - uint32_t n_embd_inp = n_embd; + uint32_t n_embd_inp = n_embd_inp_impl > 0 ? n_embd_inp_impl : n_embd; if (n_deepstack_layers > 0) { n_embd_inp += n_embd * n_deepstack_layers; diff --git a/src/llama-hparams.h b/src/llama-hparams.h index fd12a597d0..ad071b5b68 100644 --- a/src/llama-hparams.h +++ b/src/llama-hparams.h @@ -42,6 +42,7 @@ struct llama_hparams { uint32_t n_ctx_train; // context size the model was trained on uint32_t n_embd; + uint32_t n_embd_inp_impl = 0; uint32_t n_layer; int32_t n_layer_kv_from_start = -1; // if non-negative, the first n_layer_kv_from_start layers have KV cache uint32_t n_expert = 0; @@ -214,9 +215,6 @@ struct llama_hparams { // e.g., for 32-layer target: [2, 16, 29] (low, middle, high) std::array eagle3_extract_layers = {0, 0, 0}; - // EAGLE3 draft model - target model hidden size - uint32_t eagle3_target_hidden_size = 0; - // EAGLE3 draft model - apply hidden_norm before storing residual bool eagle3_norm_before_residual = false; diff --git a/src/llama-model.cpp b/src/llama-model.cpp index d3b3a1560b..ba56b5a63f 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -2774,9 +2774,9 @@ void llama_model::load_hparams(llama_model_loader & ml) { hparams.eagle3_extract_layers[2]); // EAGLE3 target model hidden size - ml.get_key(LLM_KV_EAGLE3_TARGET_HIDDEN_SIZE, hparams.eagle3_target_hidden_size); + ml.get_key(LLM_KV_EAGLE3_TARGET_HIDDEN_SIZE, hparams.n_embd_inp_impl); 
LLAMA_LOG_INFO("%s: EAGLE3 target_hidden_size = %u (draft n_embd = %u)\n", __func__, - hparams.eagle3_target_hidden_size, hparams.n_embd); + hparams.n_embd_inp_impl, hparams.n_embd); // EAGLE3 norm_before_residual (optional, default false) // compatible with Readhat eagle3 speculator model @@ -7285,7 +7285,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { } break; case LLM_ARCH_EAGLE3: { - const int64_t n_embd_target_features = 3 * hparams.eagle3_target_hidden_size; + const int64_t n_embd_inp = hparams.n_embd_inp(); const int64_t n_embd_attn_input = 2 * n_embd; // Get vocab size from the d2t tensor in the GGUF file (optional - only needed if EAGLE3 has different vocab_size than target) @@ -7302,7 +7302,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { } // Feature fusion layer: projects 3 target layers to draft hidden size - fc = create_tensor(tn(LLM_TENSOR_EAGLE3_FC, "weight"), {n_embd_target_features, n_embd}, 0); + fc = create_tensor(tn(LLM_TENSOR_EAGLE3_FC, "weight"), {n_embd_inp, n_embd}, 0); // Output layer (uses draft vocab size) output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); @@ -9178,7 +9178,7 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const { // TODO: move reranking logic here and generalize llm->build_dense_out(dense_2_out_layers, dense_2_out_layers_b, dense_3_out_layers); - llm->res->set_outputs(); + llm->res->set_outputs(params); return llm->res->get_gf(); } @@ -9583,3 +9583,11 @@ ggml_backend_dev_t llama_model_get_device(const struct llama_model * model, int } return model->devices[i].dev; } + +ggml_tensor * llama_model_get_tok_embd(const struct llama_model * model) { + return model->tok_embd; +} + +void llama_model_set_tok_embd(struct llama_model * model, ggml_tensor * tensor) { + model->tok_embd = tensor; +} diff --git a/src/models/eagle3.cpp b/src/models/eagle3.cpp index 69ac3be8c5..c8a377d113 100644 --- a/src/models/eagle3.cpp +++ b/src/models/eagle3.cpp @@ 
-1,14 +1,14 @@ #include "models.h" ggml_tensor * llm_build_eagle3_encode::build_inp_embd() const { - const int64_t n_embd_target_features = 3 * hparams.eagle3_target_hidden_size; + const int64_t n_embd_inp = hparams.n_embd_inp(); ggml_tensor * cur = nullptr; // Input: Target model features (3 layers concatenated: low, mid, high) // Data will be provided via ubatch->embd in encode_eagle3_features() - auto inp_target = std::make_unique(n_embd_target_features); - inp_target->embd = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd_target_features, n_tokens); + auto inp_target = std::make_unique(n_embd_inp); + inp_target->embd = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd_inp, n_tokens); ggml_set_input(inp_target->embd); cur = inp_target->embd; @@ -27,6 +27,9 @@ llm_build_eagle3_encode::llm_build_eagle3_encode(const llama_model & model, cons cur = build_inp_embd(); + // sanity check + GGML_ASSERT(hparams.n_embd_inp() == model.fc->ne[0]); + // Feature fusion layer cur = build_lora_mm(model.fc, cur); cb(cur, "fc_out", -1); diff --git a/src/models/llama.cpp b/src/models/llama.cpp index 2f171fd120..3c1af3ccc6 100644 --- a/src/models/llama.cpp +++ b/src/models/llama.cpp @@ -29,18 +29,10 @@ llm_build_llama::llm_build_llama(const llama_model & model, const llm_gra ggml_tensor * inp_out_ids = build_inp_out_ids(); for (int il = 0; il < n_layer; ++il) { + res->t_layer_inp[il] = inpL; + ggml_tensor * inpSA = inpL; - // EAGLE3: Extract intermediate layer features from target model at layer INPUT - if (eagle3 && cparams.eagle3_extract_enabled && !eagle3->extract_layer_indices.empty()) { - static const char * eagle3_extract_names[] = {"eagle3_extract_0", "eagle3_extract_1", "eagle3_extract_2"}; - for (size_t i = 0; i < eagle3->extract_layer_indices.size() && i < 3; ++i) { - if (eagle3->extract_layer_indices[i] == il) { - cb(inpL, eagle3_extract_names[i], il); - break; - } - } - } // norm cur = build_norm(inpL, model.layers[il].attn_norm, NULL, diff --git 
a/src/models/openai-moe-iswa.cpp b/src/models/openai-moe-iswa.cpp index 10b86f255d..2713d0d3fa 100644 --- a/src/models/openai-moe-iswa.cpp +++ b/src/models/openai-moe-iswa.cpp @@ -14,22 +14,13 @@ llm_build_openai_moe_iswa::llm_build_openai_moe_iswa(const llama_model & model, ggml_tensor * inp_out_ids = build_inp_out_ids(); for (int il = 0; il < n_layer; ++il) { + res->t_layer_inp[il] = inpL; + const float freq_base_l = model.get_rope_freq_base (cparams, il); const float freq_scale_l = model.get_rope_freq_scale(cparams, il); ggml_tensor * inpSA = inpL; - // EAGLE3: Extract intermediate layer features from target model at layer INPUT - if (eagle3 && cparams.eagle3_extract_enabled && !eagle3->extract_layer_indices.empty()) { - static const char * eagle3_extract_names[] = {"eagle3_extract_0", "eagle3_extract_1", "eagle3_extract_2"}; - for (size_t i = 0; i < eagle3->extract_layer_indices.size() && i < 3; ++i) { - if (eagle3->extract_layer_indices[i] == il) { - cb(inpL, eagle3_extract_names[i], il); - break; - } - } - } - // norm cur = build_norm(inpL, model.layers[il].attn_norm, nullptr, diff --git a/src/models/qwen3.cpp b/src/models/qwen3.cpp index dfbbcce3a3..5772dc702f 100644 --- a/src/models/qwen3.cpp +++ b/src/models/qwen3.cpp @@ -19,18 +19,9 @@ llm_build_qwen3::llm_build_qwen3(const llama_model & model, const llm_graph_para ggml_tensor * inp_out_ids = build_inp_out_ids(); for (int il = 0; il < n_layer; ++il) { - ggml_tensor * inpSA = inpL; + res->t_layer_inp[il] = inpL; - // EAGLE3: Extract intermediate layer features from target model at layer INPUT - if (eagle3 && cparams.eagle3_extract_enabled && !eagle3->extract_layer_indices.empty()) { - static const char * eagle3_extract_names[] = {"eagle3_extract_0", "eagle3_extract_1", "eagle3_extract_2"}; - for (size_t i = 0; i < eagle3->extract_layer_indices.size() && i < 3; ++i) { - if (eagle3->extract_layer_indices[i] == il) { - cb(inpL, eagle3_extract_names[i], il); - break; - } - } - } + ggml_tensor * inpSA = inpL; 
// norm cur = build_norm(inpL, diff --git a/src/models/qwen3moe.cpp b/src/models/qwen3moe.cpp index d765ace697..aa69437a54 100644 --- a/src/models/qwen3moe.cpp +++ b/src/models/qwen3moe.cpp @@ -19,18 +19,9 @@ llm_build_qwen3moe::llm_build_qwen3moe(const llama_model & model, const llm_grap ggml_tensor * inp_out_ids = build_inp_out_ids(); for (int il = 0; il < n_layer; ++il) { - ggml_tensor * inpSA = inpL; + res->t_layer_inp[il] = inpL; - // EAGLE3: Extract intermediate layer features from target model at layer INPUT - if (eagle3 && cparams.eagle3_extract_enabled && !eagle3->extract_layer_indices.empty()) { - static const char * eagle3_extract_names[] = {"eagle3_extract_0", "eagle3_extract_1", "eagle3_extract_2"}; - for (size_t i = 0; i < eagle3->extract_layer_indices.size() && i < 3; ++i) { - if (eagle3->extract_layer_indices[i] == il) { - cb(inpL, eagle3_extract_names[i], il); - break; - } - } - } + ggml_tensor * inpSA = inpL; // norm cur = build_norm(inpL,