mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2026-05-12 12:04:08 +00:00
cont : pass seq_id
[no ci]
This commit is contained in:
@@ -176,6 +176,8 @@ struct common_speculative_state_draft : public common_speculative_state {
|
||||
bool use_ckpt = false;
|
||||
common_speculative_checkpoint ckpt;
|
||||
|
||||
llama_seq_id seq_id;
|
||||
|
||||
common_sampler * smpl;
|
||||
|
||||
llama_batch batch;
|
||||
@@ -185,11 +187,13 @@ struct common_speculative_state_draft : public common_speculative_state {
|
||||
enum common_speculative_type type,
|
||||
llama_context * ctx_tgt,
|
||||
llama_context * ctx_dft,
|
||||
bool use_ckpt)
|
||||
bool use_ckpt,
|
||||
llama_seq_id seq_id)
|
||||
: common_speculative_state(type)
|
||||
, ctx_tgt(ctx_tgt)
|
||||
, ctx_dft(ctx_dft)
|
||||
, use_ckpt(use_ckpt)
|
||||
, seq_id(seq_id)
|
||||
{
|
||||
batch = llama_batch_init(llama_n_batch(ctx_dft), 0, 1);
|
||||
smpl = nullptr;
|
||||
@@ -241,15 +245,14 @@ struct common_speculative_state_draft : public common_speculative_state {
|
||||
}
|
||||
|
||||
size_t create_checkpoint(int n_tokens_prompt) {
|
||||
int slot_id = 0;
|
||||
const size_t checkpoint_size = llama_state_seq_get_size_ext(ctx_dft, slot_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE);
|
||||
const size_t checkpoint_size = llama_state_seq_get_size_ext(ctx_dft, seq_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE);
|
||||
|
||||
ckpt.pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx_dft), slot_id);
|
||||
ckpt.pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx_dft), slot_id);
|
||||
ckpt.pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx_dft), seq_id);
|
||||
ckpt.pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx_dft), seq_id);
|
||||
ckpt.n_tokens = n_tokens_prompt;
|
||||
ckpt.data.resize(checkpoint_size);
|
||||
|
||||
const size_t n = llama_state_seq_get_data_ext(ctx_dft, ckpt.data.data(), checkpoint_size, slot_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE);
|
||||
const size_t n = llama_state_seq_get_data_ext(ctx_dft, ckpt.data.data(), checkpoint_size, seq_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE);
|
||||
if (n != checkpoint_size) {
|
||||
GGML_ABORT("checkpoint size mismatch: expected %zu, got %zu\n", checkpoint_size, n);
|
||||
}
|
||||
@@ -260,14 +263,14 @@ struct common_speculative_state_draft : public common_speculative_state {
|
||||
}
|
||||
|
||||
size_t restore_checkpoint() {
|
||||
int slot_id = 0;
|
||||
int seq_id = 0;
|
||||
LOG_DBG("%s: pos_min = %d, pos_max = %d\n", __func__, ckpt.pos_min, ckpt.pos_max);
|
||||
const size_t n = llama_state_seq_set_data_ext(ctx_dft, ckpt.data.data(), ckpt.size(), slot_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE);
|
||||
const size_t n = llama_state_seq_set_data_ext(ctx_dft, ckpt.data.data(), ckpt.size(), seq_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE);
|
||||
if (n != ckpt.size()) {
|
||||
GGML_ABORT("%s: failed to restore context checkpoint (pos_min=%d, pos_max=%d, size=%zu",
|
||||
__func__, ckpt.pos_min, ckpt.pos_max, ckpt.size());
|
||||
}
|
||||
llama_memory_seq_rm(llama_get_memory(ctx_dft), slot_id, ckpt.pos_max + 1, -1);
|
||||
llama_memory_seq_rm(llama_get_memory(ctx_dft), seq_id, ckpt.pos_max + 1, -1);
|
||||
|
||||
return n;
|
||||
}
|
||||
@@ -896,7 +899,7 @@ enum common_speculative_type common_speculative_type_from_name(const std::string
|
||||
|
||||
// initialization of the speculative decoding system
|
||||
//
|
||||
common_speculative * common_speculative_init(common_params_speculative & params) {
|
||||
common_speculative * common_speculative_init(common_params_speculative & params, llama_seq_id seq_id) {
|
||||
// Compute the implementations to use based on the config and their order of preference
|
||||
std::vector<common_speculative_config> configs = {}; // list of speculative configs to try
|
||||
{
|
||||
@@ -961,7 +964,8 @@ common_speculative * common_speculative_init(common_params_speculative & params)
|
||||
impls.push_back(std::make_unique<common_speculative_state_draft>(config.type,
|
||||
/* .ctx_tgt = */ params.draft.ctx_tgt,
|
||||
/* .ctx_dft = */ params.draft.ctx_dft,
|
||||
/* .use_ckpt = */ params.draft.use_ckpt
|
||||
/* .use_ckpt = */ params.draft.use_ckpt,
|
||||
/* .seq_id = */ seq_id
|
||||
));
|
||||
break;
|
||||
}
|
||||
|
||||
@@ -14,7 +14,7 @@ enum common_speculative_type common_speculative_type_from_name(const std::string
|
||||
// convert type to string
|
||||
std::string common_speculative_type_to_str(enum common_speculative_type type);
|
||||
|
||||
common_speculative * common_speculative_init(common_params_speculative & params);
|
||||
common_speculative * common_speculative_init(common_params_speculative & params, llama_seq_id seq_id);
|
||||
|
||||
void common_speculative_free(common_speculative * spec);
|
||||
|
||||
|
||||
@@ -162,7 +162,7 @@ int main(int argc, char ** argv) {
|
||||
// init the speculator
|
||||
const auto & params_spec = params.speculative;
|
||||
|
||||
struct common_speculative * spec = common_speculative_init(params.speculative);
|
||||
struct common_speculative * spec = common_speculative_init(params.speculative, 0);
|
||||
|
||||
common_speculative_begin(spec, prompt_tgt);
|
||||
|
||||
|
||||
@@ -912,7 +912,7 @@ private:
|
||||
// try speculative decoding
|
||||
if (ctx_seq_rm_type != COMMON_CONTEXT_SEQ_RM_TYPE_NO) {
|
||||
try {
|
||||
slot.spec.reset(common_speculative_init(params_base.speculative));
|
||||
slot.spec.reset(common_speculative_init(params_base.speculative, slot.id));
|
||||
} catch (const std::exception & e) {
|
||||
SRV_ERR("failed to initialize speculative decoding context: %s\n", e.what());
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user