mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2026-05-12 20:14:09 +00:00
cont : clean-up
This commit is contained in:
@@ -616,7 +616,6 @@ private:
|
||||
common_context_seq_rm_type ctx_dft_seq_rm_type = COMMON_CONTEXT_SEQ_RM_TYPE_NO;
|
||||
|
||||
common_speculative_ptr spec;
|
||||
common_speculative_draft_params_vec spec_dparams;
|
||||
|
||||
bool add_bos_token = true;
|
||||
|
||||
@@ -838,7 +837,6 @@ private:
|
||||
if (ctx_tgt_seq_rm_type != COMMON_CONTEXT_SEQ_RM_TYPE_NO) {
|
||||
try {
|
||||
spec.reset(common_speculative_init(params_base.speculative, params_base.n_parallel));
|
||||
spec_dparams.resize(params_base.n_parallel);
|
||||
} catch (const std::exception & e) {
|
||||
SRV_ERR("failed to initialize speculative decoding context: %s\n", e.what());
|
||||
}
|
||||
@@ -2186,10 +2184,11 @@ private:
|
||||
// track if given slot can be batched with slots already in the batch
|
||||
server_slot * slot_batched = nullptr;
|
||||
|
||||
// first, process slots that are speculative decoding
|
||||
for (auto & slot : slots) {
|
||||
spec_dparams[slot.id].drafting = false;
|
||||
std::vector<server_slot *> generating;
|
||||
std::vector<server_slot *> drafting;
|
||||
|
||||
// determine which slots are generating and drafting
|
||||
for (auto & slot : slots) {
|
||||
if (slot.state != SLOT_STATE_GENERATING) {
|
||||
continue;
|
||||
}
|
||||
@@ -2201,6 +2200,12 @@ private:
|
||||
continue;
|
||||
}
|
||||
|
||||
generating.push_back(&slot);
|
||||
|
||||
if (spec) {
|
||||
common_speculative_get_draft_params(spec.get(), slot.id).drafting = false;
|
||||
}
|
||||
|
||||
const bool use_ckpt_tgt = ctx_tgt_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL;
|
||||
const bool use_ckpt_dft = ctx_dft_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL;
|
||||
|
||||
@@ -2226,34 +2231,30 @@ private:
|
||||
slot.spec_ckpt.update_dft(ctx_dft.get(), slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE);
|
||||
}
|
||||
|
||||
spec_dparams[slot.id] = common_speculative_draft_params {
|
||||
slot.spec_prompt = slot.prompt.tokens.get_text_tokens();
|
||||
|
||||
common_speculative_get_draft_params(spec.get(), slot.id) = {
|
||||
/* .drafting = */ true,
|
||||
/* .n_max = */ n_draft_max,
|
||||
/* .n_past = */ slot.prompt.n_tokens(),
|
||||
/* .id_last = */ slot.sampled,
|
||||
/* .prompt = */ nullptr,
|
||||
/* .prompt = */ &slot.spec_prompt,
|
||||
/* .result = */ &slot.spec_draft,
|
||||
};
|
||||
|
||||
drafting.push_back(&slot);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// generate the actual drafts (if any)
|
||||
{
|
||||
for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) spec_dparams.size(); seq_id++) {
|
||||
auto & slot = slots[seq_id];
|
||||
auto & dp = spec_dparams[seq_id];
|
||||
|
||||
slot.spec_prompt = slot.prompt.tokens.get_text_tokens();
|
||||
|
||||
dp.prompt = &slot.spec_prompt;
|
||||
}
|
||||
|
||||
common_speculative_draft(spec.get(), spec_dparams);
|
||||
common_speculative_draft(spec.get());
|
||||
}
|
||||
|
||||
for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) spec_dparams.size(); seq_id++) {
|
||||
auto & slot = slots[seq_id];
|
||||
// make checkpoints if needed
|
||||
for (auto * slot_ptr : drafting) {
|
||||
auto & slot = *slot_ptr;
|
||||
|
||||
slot.n_draft_total += slot.spec_draft.size();
|
||||
|
||||
@@ -2282,20 +2283,9 @@ private:
|
||||
}
|
||||
}
|
||||
|
||||
slot_batched = nullptr;
|
||||
|
||||
// add the speculative drafts to the batch, or simply add the sampled tokens
|
||||
for (auto & slot : slots) {
|
||||
if (slot.state != SLOT_STATE_GENERATING) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// check if we can batch this slot with the previous one
|
||||
if (!slot_batched) {
|
||||
slot_batched = &slot;
|
||||
} else if (!slot_batched->can_batch_with(slot)) {
|
||||
continue;
|
||||
}
|
||||
// update the batch with the sampled/drafted tokens
|
||||
for (auto * slot_ptr : generating) {
|
||||
auto & slot = *slot_ptr;
|
||||
|
||||
slot.update_batch(batch);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user