mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2026-05-15 05:24:06 +00:00
cont : minor
This commit is contained in:
@@ -938,7 +938,7 @@ void common_speculative_begin(common_speculative * spec, llama_seq_id seq_id, co
|
||||
}
|
||||
|
||||
void common_speculative_draft(common_speculative * spec) {
|
||||
if (!spec) {
|
||||
if (spec == nullptr) {
|
||||
return;
|
||||
}
|
||||
|
||||
|
||||
@@ -46,11 +46,10 @@ void common_speculative_begin(common_speculative * spec, llama_seq_id seq_id, co
|
||||
// TODO: implement [TAG_COMMON_SPECULATIVE_PROCESS]
|
||||
//bool common_speculative_process(common_speculative * spec, const llama_batch & batch);
|
||||
|
||||
// generate drafts for the sequences specified in dparams
|
||||
// requires that `dparams.size() == n_seq` using during common_speculative_init()
|
||||
// generate drafts for the sequences specified with `common_speculative_get_draft_params`
|
||||
void common_speculative_draft(common_speculative * spec);
|
||||
|
||||
// informs the speculative decoder that n_accepted tokens were accepted by the target model
|
||||
// informs the speculative context that n_accepted tokens were accepted by the target model
|
||||
void common_speculative_accept(common_speculative * spec, llama_seq_id, uint16_t n_accepted);
|
||||
|
||||
// print statistics about the speculative decoding
|
||||
|
||||
@@ -2484,6 +2484,7 @@ public:
|
||||
} else {
|
||||
//LLAMA_LOG_INFO("%s: reallocating tensors in '%s' buffer %.3f MiB\n", __func__, ggml_backend_buft_name(buft), mbuf.total_size/1024.0/1024.0);
|
||||
|
||||
// save the old buffer and allocate the new tensors in it
|
||||
auto buf = std::move(mbuf_cur.buf);
|
||||
|
||||
mbuf_cur = std::move(mbuf);
|
||||
|
||||
@@ -2204,45 +2204,45 @@ private:
|
||||
|
||||
if (spec) {
|
||||
common_speculative_get_draft_params(spec.get(), slot.id).drafting = false;
|
||||
}
|
||||
|
||||
const bool use_ckpt_tgt = ctx_tgt_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL;
|
||||
const bool use_ckpt_dft = ctx_dft_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL;
|
||||
const bool use_ckpt_tgt = ctx_tgt_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL;
|
||||
const bool use_ckpt_dft = ctx_dft_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL;
|
||||
|
||||
const int n_draft_max = slot.get_n_draft_max();
|
||||
const int n_draft_max = slot.get_n_draft_max();
|
||||
|
||||
if (n_draft_max > 0) {
|
||||
GGML_ASSERT(slot.can_speculate());
|
||||
if (n_draft_max > 0) {
|
||||
GGML_ASSERT(slot.can_speculate());
|
||||
|
||||
if (!slot.spec_draft.empty()) {
|
||||
// we have a previous (partial) draft to reuse
|
||||
if (use_ckpt_tgt) {
|
||||
GGML_ASSERT(!slot.spec_ckpt.empty());
|
||||
if (!slot.spec_draft.empty()) {
|
||||
// we have a previous (partial) draft to reuse
|
||||
if (use_ckpt_tgt) {
|
||||
GGML_ASSERT(!slot.spec_ckpt.empty());
|
||||
}
|
||||
} else {
|
||||
GGML_ASSERT(slot.spec_i_batch.empty());
|
||||
|
||||
slot.spec_ckpt.update_pos(
|
||||
slot.prompt.n_tokens(),
|
||||
llama_memory_seq_pos_min(llama_get_memory(ctx_tgt), slot.id),
|
||||
llama_memory_seq_pos_max(llama_get_memory(ctx_tgt), slot.id));
|
||||
|
||||
if (use_ckpt_dft) {
|
||||
slot.spec_ckpt.update_dft(ctx_dft.get(), slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE);
|
||||
}
|
||||
|
||||
slot.spec_prompt = slot.prompt.tokens.get_text_tokens();
|
||||
|
||||
common_speculative_get_draft_params(spec.get(), slot.id) = {
|
||||
/* .drafting = */ true,
|
||||
/* .n_max = */ n_draft_max,
|
||||
/* .n_past = */ slot.prompt.n_tokens(),
|
||||
/* .id_last = */ slot.sampled,
|
||||
/* .prompt = */ &slot.spec_prompt,
|
||||
/* .result = */ &slot.spec_draft,
|
||||
};
|
||||
|
||||
drafting.push_back(&slot);
|
||||
}
|
||||
} else {
|
||||
GGML_ASSERT(slot.spec_i_batch.empty());
|
||||
|
||||
slot.spec_ckpt.update_pos(
|
||||
slot.prompt.n_tokens(),
|
||||
llama_memory_seq_pos_min(llama_get_memory(ctx_tgt), slot.id),
|
||||
llama_memory_seq_pos_max(llama_get_memory(ctx_tgt), slot.id));
|
||||
|
||||
if (use_ckpt_dft) {
|
||||
slot.spec_ckpt.update_dft(ctx_dft.get(), slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE);
|
||||
}
|
||||
|
||||
slot.spec_prompt = slot.prompt.tokens.get_text_tokens();
|
||||
|
||||
common_speculative_get_draft_params(spec.get(), slot.id) = {
|
||||
/* .drafting = */ true,
|
||||
/* .n_max = */ n_draft_max,
|
||||
/* .n_past = */ slot.prompt.n_tokens(),
|
||||
/* .id_last = */ slot.sampled,
|
||||
/* .prompt = */ &slot.spec_prompt,
|
||||
/* .result = */ &slot.spec_draft,
|
||||
};
|
||||
|
||||
drafting.push_back(&slot);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -2256,29 +2256,33 @@ private:
|
||||
for (auto * slot_ptr : drafting) {
|
||||
auto & slot = *slot_ptr;
|
||||
|
||||
slot.n_draft_total += slot.spec_draft.size();
|
||||
auto & draft = slot.spec_draft;
|
||||
auto & ckpt = slot.spec_ckpt;
|
||||
|
||||
slot.n_draft_total += draft.size();
|
||||
|
||||
// TODO: avoid restoring the draft context and re-evaluating the drafted tokens when not needed [TAG_SPEC_AVOID_DRAFT_REEVAL]
|
||||
if (ctx_dft) {
|
||||
slot.spec_ckpt.load_dft(ctx_dft.get(), slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE);
|
||||
ckpt.load_dft(ctx_dft.get(), slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE);
|
||||
|
||||
llama_memory_seq_rm(llama_get_memory(ctx_dft.get()), slot.id, slot.spec_ckpt.pos_max + 1, -1);
|
||||
llama_memory_seq_rm(llama_get_memory(ctx_dft.get()), slot.id, ckpt.pos_max + 1, -1);
|
||||
}
|
||||
|
||||
const bool use_ckpt_tgt = ctx_tgt_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL;
|
||||
if (!draft.empty()) {
|
||||
const bool use_ckpt_tgt = ctx_tgt_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL;
|
||||
|
||||
if (!slot.spec_draft.empty()) {
|
||||
if (use_ckpt_tgt) {
|
||||
//const int64_t t_start = ggml_time_us();
|
||||
|
||||
slot.spec_ckpt.update_tgt(ctx_tgt, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE);
|
||||
ckpt.update_tgt(ctx_tgt, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE);
|
||||
|
||||
//const int64_t t_total = ggml_time_us() - t_start;
|
||||
//printf("checkpoint total: %f ms\n", t_total / 1000.0);
|
||||
|
||||
SLT_DBG(slot, "created speculative checkpoint (pos_min = %d, pos_max = %d, n_tokens = %d, size = %.3f MiB, draft = %.3f MiB)\n",
|
||||
slot.spec_ckpt.pos_min, slot.spec_ckpt.pos_max, slot.prompt.n_tokens(),
|
||||
(float) slot.spec_ckpt.size() / 1024 / 1024, (float) slot.spec_ckpt.data_dft.size() / 1024 / 1024);
|
||||
ckpt.pos_min, ckpt.pos_max, slot.prompt.n_tokens(),
|
||||
(float) ckpt.size() / 1024 / 1024,
|
||||
(float) ckpt.data_dft.size() / 1024 / 1024);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user