Files
llama.cpp/common/speculative.h
Georgi Gerganov 68e7ea3eab spec : parallel drafting support (#22838)
* spec : refactor

* spec : drop support for incompatible vocabs

* spec : update common_speculative_init()

* cont : pass seq_id

* cont : dedup ctx_seq_rm_type

* server : sketch the ctx_dft decode loop

* server : draft prompt cache and checkpoints

* server : improve ctx names

* server, spec : transition to unified spec context

* cont : sync main and drft contexts

* cont : async drft eval when possible

* cont : handle non-ckpt models

* cont : pass correct n_past for drafting

* cont : process images through the draft context

* spec : handle draft running out of context

* server : fix mtmd draft processing

* server : fix URL for draft model

* server : add comment

* server : clean-up + dry

* speculative-simple : update

* spec : fix n_past type

* server : fix slot ctx_drft ptr

* tools : update readme

* naming : improve consistency

* spec : refactor for multi-sequence speculative context

* cont : prepare params

* cont : prepare params

* spec : support parallel drafts

* server : support parallel drafting

* llama : reuse device buffers when possible

* server, spec : clean-up

* cont : clean-up

* cont : minor

* spec : reset `drafting` flag at the end

* spec : introduce `common_speculative_process()`

* spec : allow for multiple spec types (chain of speculators)

* replace old type field of type common_speculative_type in the
  common_params_speculative struct with a vector to allow multiple
  types to be specified

* introduce common_get_enabled_speculative_impls(const std::vector<enum common_speculative_type>)
  to figure out which implementations the user has enabled

* introduce common_speculative_types_from_names(const std::vector<std::string> & names)
  to parse the user-provided spec types

* all speculators run sequentially, best one wins (we verify its drafted tokens)

* maximize expected accepted tokens for current round by calculating the
  product between the probability of accepting current token (n_acc_tokens / n_gen_drafts)
  and the draft's length

---------

Co-authored-by: Petros Sideris <petros.sideris@nokia.com>
2026-05-11 19:09:43 +03:00

70 lines
2.7 KiB
C++

#pragma once
#include "llama.h"
#include "common.h"
// opaque speculative decoding context; this header only ever handles it through pointers
struct common_speculative;
// comma separated list of the provided types
std::string common_speculative_type_name_str(const std::vector<enum common_speculative_type> & types);
// comma separated list of all types
const char * common_speculative_all_types_str();
// parse user provided types (one entry of `names` per type)
// NOTE(review): handling of unrecognized names is not visible in this header - confirm in the implementation
std::vector<enum common_speculative_type> common_speculative_types_from_names(const std::vector<std::string> & names);
// convert string to type
// NOTE(review): the value returned for an unrecognized name is not visible here - confirm
enum common_speculative_type common_speculative_type_from_name(const std::string & name);
// convert type to string
std::string common_speculative_type_to_str(enum common_speculative_type type);
// create a speculative context from `params`
// NOTE(review): `n_seq` is presumably the number of sequences the context will serve - confirm
// NOTE(review): `params` is taken by non-const reference, so it may be modified - confirm
// the returned pointer is released with common_speculative_free() (see common_speculative_ptr)
common_speculative * common_speculative_init(common_params_speculative & params, uint32_t n_seq);
// release a speculative context; this is the function invoked by common_speculative_deleter
void common_speculative_free(common_speculative * spec);
// draft request/result slot for a single sequence
// (a reference to it is obtained with common_speculative_get_draft_params())
//
// every member has a default initializer so that a default-initialized instance
// never carries indeterminate values or dangling pointers; the defaults match
// what value-initialization (`common_speculative_draft_params{}`) already produced
struct common_speculative_draft_params {
    // this flag is used to chain the drafts through all the available implementations
    // after the first successful draft from an implementation, we set it
    // to false to prevent further drafts for that sequence
    // at the end of the draft() call, all drafting flags will be reset to false
    bool drafting = false;

    // overrides individual configurations (-1 disabled)
    // can be used to constrain the max draft based on the remaining context size
    int32_t n_max = -1;

    // NOTE(review): presumably the position to draft from and the last token produced
    //               by the target model - confirm against the callers
    llama_pos   n_past  = 0;
    llama_token id_last = 0;

    // TODO: remove in the future by keeping track of the prompt from the _begin() call and the consecutive accept calls
    // non-owning; nullptr until the caller provides a prompt
    const llama_tokens * prompt = nullptr;

    // the generated draft from the last _draft() call
    // non-owning; nullptr until the caller provides an output buffer
    llama_tokens * result = nullptr;
};
// access the draft parameters of sequence `seq_id`
// returns a mutable reference: fill it in before calling common_speculative_draft()
// NOTE(review): the lifetime of the returned reference is presumably tied to `spec` - confirm
common_speculative_draft_params & common_speculative_get_draft_params(common_speculative * spec, llama_seq_id seq_id);
// optionally call once at the beginning of a new generation
void common_speculative_begin(common_speculative * spec, llama_seq_id seq_id, const llama_tokens & prompt);
// process the batch and update the internal state of the speculative context
// NOTE(review): the meaning of the bool result is not visible in this header - presumably success/failure, confirm
bool common_speculative_process(common_speculative * spec, const llama_batch & batch);
// generate drafts for the sequences specified with `common_speculative_get_draft_params`
// NOTE(review): sequences are presumably selected by setting `drafting = true` in their params - confirm
void common_speculative_draft(common_speculative * spec);
// informs the speculative context that n_accepted tokens were accepted by the target model
// `seq_id` identifies the sequence (named here for consistency with the other per-sequence calls)
void common_speculative_accept(common_speculative * spec, llama_seq_id seq_id, uint16_t n_accepted);
// print statistics about the speculative decoding
// takes the context as pointer-to-const
// NOTE(review): the output destination (log/stdout) is not visible in this header - confirm
void common_speculative_print_stats(const common_speculative * spec);
// deleter functor: hands the context back to common_speculative_free()
// intended for use as the deleter type of a std::unique_ptr
struct common_speculative_deleter {
    void operator()(common_speculative * s) {
        common_speculative_free(s);
    }
};
// owning handle for a speculative context; destruction goes through common_speculative_deleter
using common_speculative_ptr = std::unique_ptr<common_speculative, common_speculative_deleter>;