cont : pass seq_id

[no ci]
This commit is contained in:
Georgi Gerganov
2026-05-07 10:14:18 +03:00
parent 2e389e1df9
commit ab8875567c
4 changed files with 18 additions and 14 deletions

View File

@@ -176,6 +176,8 @@ struct common_speculative_state_draft : public common_speculative_state {
bool use_ckpt = false;
common_speculative_checkpoint ckpt;
llama_seq_id seq_id;
common_sampler * smpl;
llama_batch batch;
@@ -185,11 +187,13 @@ struct common_speculative_state_draft : public common_speculative_state {
enum common_speculative_type type,
llama_context * ctx_tgt,
llama_context * ctx_dft,
bool use_ckpt)
bool use_ckpt,
llama_seq_id seq_id)
: common_speculative_state(type)
, ctx_tgt(ctx_tgt)
, ctx_dft(ctx_dft)
, use_ckpt(use_ckpt)
, seq_id(seq_id)
{
batch = llama_batch_init(llama_n_batch(ctx_dft), 0, 1);
smpl = nullptr;
@@ -241,15 +245,14 @@ struct common_speculative_state_draft : public common_speculative_state {
}
size_t create_checkpoint(int n_tokens_prompt) {
int slot_id = 0;
const size_t checkpoint_size = llama_state_seq_get_size_ext(ctx_dft, slot_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE);
const size_t checkpoint_size = llama_state_seq_get_size_ext(ctx_dft, seq_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE);
ckpt.pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx_dft), slot_id);
ckpt.pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx_dft), slot_id);
ckpt.pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx_dft), seq_id);
ckpt.pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx_dft), seq_id);
ckpt.n_tokens = n_tokens_prompt;
ckpt.data.resize(checkpoint_size);
const size_t n = llama_state_seq_get_data_ext(ctx_dft, ckpt.data.data(), checkpoint_size, slot_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE);
const size_t n = llama_state_seq_get_data_ext(ctx_dft, ckpt.data.data(), checkpoint_size, seq_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE);
if (n != checkpoint_size) {
GGML_ABORT("checkpoint size mismatch: expected %zu, got %zu\n", checkpoint_size, n);
}
@@ -260,14 +263,14 @@ struct common_speculative_state_draft : public common_speculative_state {
}
size_t restore_checkpoint() {
int slot_id = 0;
int seq_id = 0;
LOG_DBG("%s: pos_min = %d, pos_max = %d\n", __func__, ckpt.pos_min, ckpt.pos_max);
const size_t n = llama_state_seq_set_data_ext(ctx_dft, ckpt.data.data(), ckpt.size(), slot_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE);
const size_t n = llama_state_seq_set_data_ext(ctx_dft, ckpt.data.data(), ckpt.size(), seq_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE);
if (n != ckpt.size()) {
GGML_ABORT("%s: failed to restore context checkpoint (pos_min=%d, pos_max=%d, size=%zu",
__func__, ckpt.pos_min, ckpt.pos_max, ckpt.size());
}
llama_memory_seq_rm(llama_get_memory(ctx_dft), slot_id, ckpt.pos_max + 1, -1);
llama_memory_seq_rm(llama_get_memory(ctx_dft), seq_id, ckpt.pos_max + 1, -1);
return n;
}
@@ -896,7 +899,7 @@ enum common_speculative_type common_speculative_type_from_name(const std::string
// initialization of the speculative decoding system
//
common_speculative * common_speculative_init(common_params_speculative & params) {
common_speculative * common_speculative_init(common_params_speculative & params, llama_seq_id seq_id) {
// Compute the implementations to use based on the config and their order of preference
std::vector<common_speculative_config> configs = {}; // list of speculative configs to try
{
@@ -961,7 +964,8 @@ common_speculative * common_speculative_init(common_params_speculative & params)
impls.push_back(std::make_unique<common_speculative_state_draft>(config.type,
/* .ctx_tgt = */ params.draft.ctx_tgt,
/* .ctx_dft = */ params.draft.ctx_dft,
/* .use_ckpt = */ params.draft.use_ckpt
/* .use_ckpt = */ params.draft.use_ckpt,
/* .seq_id = */ seq_id
));
break;
}

View File

@@ -14,7 +14,7 @@ enum common_speculative_type common_speculative_type_from_name(const std::string
// convert type to string
std::string common_speculative_type_to_str(enum common_speculative_type type);
common_speculative * common_speculative_init(common_params_speculative & params);
common_speculative * common_speculative_init(common_params_speculative & params, llama_seq_id seq_id);
void common_speculative_free(common_speculative * spec);

View File

@@ -162,7 +162,7 @@ int main(int argc, char ** argv) {
// init the speculator
const auto & params_spec = params.speculative;
struct common_speculative * spec = common_speculative_init(params.speculative);
struct common_speculative * spec = common_speculative_init(params.speculative, 0);
common_speculative_begin(spec, prompt_tgt);

View File

@@ -912,7 +912,7 @@ private:
// try speculative decoding
if (ctx_seq_rm_type != COMMON_CONTEXT_SEQ_RM_TYPE_NO) {
try {
slot.spec.reset(common_speculative_init(params_base.speculative));
slot.spec.reset(common_speculative_init(params_base.speculative, slot.id));
} catch (const std::exception & e) {
SRV_ERR("failed to initialize speculative decoding context: %s\n", e.what());
}