mirror of https://github.com/ggml-org/llama.cpp.git
synced 2026-05-14 13:04:08 +00:00
* spec : refactor
* spec : drop support for incompatible vocabs
* spec : update common_speculative_init()
* cont : pass seq_id
* cont : dedup ctx_seq_rm_type
* server : sketch the ctx_dft decode loop
* server : draft prompt cache and checkpoints
* server : improve ctx names
* server, spec : transition to unified spec context
* cont : sync main and drft contexts
* cont : async drft eval when possible
* cont : handle non-ckpt models
* cont : pass correct n_past for drafting
* cont : process images through the draft context
* spec : handle draft running out of context
* server : fix mtmd draft processing
* server : fix URL for draft model
* server : add comment
* server : clean-up + dry
* speculative-simple : update
* spec : fix n_past type
* server : fix slot ctx_drft ptr
* tools : update readme
* naming : improve consistency
* spec : refactor for multi-sequence speculative context
* cont : prepare params
* cont : prepare params
* spec : support parallel drafts
* server : support parallel drafting
* llama : reuse device buffers when possible
* server, spec : clean-up
* cont : clean-up
* cont : minor
* spec : reset `drafting` flag at the end
* spec : introduce `common_speculative_process()`
* spec : allow for multiple spec types (chain of speculators)
* replace the old `type` field of type `common_speculative_type` in the `common_params_speculative` struct with a vector, to allow multiple types to be specified
* introduce `common_get_enabled_speculative_impls(const std::vector<enum common_speculative_type>)` to figure out which implementations the user has enabled
* introduce `common_speculative_type_from_names(const std::vector<std::string> & names)` to parse the user-provided spec types
* all speculators run sequentially and the best one wins (we verify its drafted tokens)
* maximize the expected number of accepted tokens for the current round by scoring each draft with the product of the probability of accepting the current token (n_acc_tokens / n_gen_drafts) and the draft's length (see the sketch below)

---------

Co-authored-by: Petros Sideris <petros.sideris@nokia.com>
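The selection rule in the last two bullets reduces to a single score per speculator. Below is a minimal C++ sketch of that rule; `draft_candidate` and `pick_best_draft` are hypothetical names for illustration, not part of the llama.cpp API, and the counter semantics follow the commit message above.

#include <cstddef>
#include <vector>

// per-speculator state: the draft produced this round plus running counters
// (hypothetical layout; the counter semantics follow the commit message)
struct draft_candidate {
    std::vector<int> tokens; // drafted token ids for the current round
    int n_acc_tokens;        // drafted tokens accepted so far
    int n_gen_drafts;        // drafted tokens generated so far
};

// best one wins: maximize the expected number of accepted tokens,
// score = p_accept * draft_length, with p_accept = n_acc_tokens / n_gen_drafts
static size_t pick_best_draft(const std::vector<draft_candidate> & cands) {
    size_t best       = 0;
    double best_score = -1.0;
    for (size_t i = 0; i < cands.size(); ++i) {
        const draft_candidate & c = cands[i];
        const double p_acc = c.n_gen_drafts > 0 ? (double) c.n_acc_tokens / c.n_gen_drafts : 0.0;
        const double score = p_acc * (double) c.tokens.size();
        if (score > best_score) {
            best_score = score;
            best       = i;
        }
    }
    return best;
}

The tokens that get verified against the target model are then those of cands[pick_best_draft(cands)].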
70 lines
2.7 KiB
C++
#pragma once

#include "llama.h"
#include "common.h"

struct common_speculative;

// comma-separated list of the provided types
std::string common_speculative_type_name_str(const std::vector<enum common_speculative_type> & types);

// comma-separated list of all types
const char * common_speculative_all_types_str();

// parse the user-provided types
std::vector<enum common_speculative_type> common_speculative_types_from_names(const std::vector<std::string> & names);

// convert string to type
enum common_speculative_type common_speculative_type_from_name(const std::string & name);

// convert type to string
std::string common_speculative_type_to_str(enum common_speculative_type type);
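
// Example (illustrative sketch, not part of this header): round-tripping the
// user-provided type names. The literal names below are hypothetical; the
// valid set is whatever common_speculative_all_types_str() reports.
//
//   std::vector<std::string> names = { "draft", "ngram" }; // hypothetical names
//
//   std::vector<enum common_speculative_type> types = common_speculative_types_from_names(names);
//
//   printf("speculative types: %s (available: %s)\n",
//           common_speculative_type_name_str(types).c_str(),
//           common_speculative_all_types_str());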

common_speculative * common_speculative_init(common_params_speculative & params, uint32_t n_seq);

void common_speculative_free(common_speculative * spec);

struct common_speculative_draft_params {
    // this flag is used to chain the drafts through all the available implementations
    // after the first successful draft from an implementation, we set it
    // to false to prevent further drafts for that sequence
    // at the end of the draft() call, all drafting flags will be reset to false
    bool drafting = false;

    // overrides the individual configurations (-1 = disabled)
    // can be used to constrain the max draft size based on the remaining context size
    int32_t n_max = -1;

    llama_pos   n_past;
    llama_token id_last;

    // TODO: remove in the future by keeping track of the prompt from the _begin() call and the consecutive accept calls
    const llama_tokens * prompt;

    // the generated draft from the last _draft() call
    llama_tokens * result;
};

common_speculative_draft_params & common_speculative_get_draft_params(common_speculative * spec, llama_seq_id seq_id);
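
// Example (illustrative sketch): requesting a draft for one sequence; `spec`,
// `seq_id`, `n_ctx`, `n_past`, `id_last`, `prompt_tokens` and `draft_tokens`
// are assumed to exist in the caller.
//
//   common_speculative_draft_params & dparams = common_speculative_get_draft_params(spec, seq_id);
//
//   dparams.drafting = true;             // mark this sequence as wanting a draft
//   dparams.n_max    = n_ctx - n_past;   // don't draft past the remaining context
//   dparams.n_past   = n_past;
//   dparams.id_last  = id_last;          // last token sampled from the target
//   dparams.prompt   = &prompt_tokens;
//   dparams.result   = &draft_tokens;    // filled in by common_speculative_draft()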

// optionally call once at the beginning of a new generation
void common_speculative_begin(common_speculative * spec, llama_seq_id seq_id, const llama_tokens & prompt);

// process the batch and update the internal state of the speculative context
bool common_speculative_process(common_speculative * spec, const llama_batch & batch);

// generate drafts for the sequences specified with `common_speculative_get_draft_params`
void common_speculative_draft(common_speculative * spec);
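
// Example (illustrative sketch): one round of speculative decoding against a
// target context `ctx_tgt`. The batch construction, sampling and verification
// are elided, and the ordering of _process() relative to the target decode is
// an assumption, not taken from this header.
//
//   common_speculative_begin(spec, seq_id, prompt_tokens);
//
//   while (generating) {
//       // request a draft for this sequence (see the params sketch above)
//       common_speculative_draft(spec);
//
//       // decode id_last + draft_tokens on the target and mirror the batch
//       // into the speculative context
//       common_speculative_process(spec, batch);
//       llama_decode(ctx_tgt, batch);
//
//       // count how many drafted tokens the target agreed with
//       common_speculative_accept(spec, seq_id, n_accepted);
//   }
//
//   common_speculative_print_stats(spec);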

// informs the speculative context that n_accepted tokens were accepted by the target model
void common_speculative_accept(common_speculative * spec, llama_seq_id seq_id, uint16_t n_accepted);

// print statistics about the speculative decoding
void common_speculative_print_stats(const common_speculative * spec);

struct common_speculative_deleter {
    void operator()(common_speculative * s) { common_speculative_free(s); }
};

typedef std::unique_ptr<common_speculative, common_speculative_deleter> common_speculative_ptr;
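
// Example (illustrative sketch): tying the context lifetime to a scope;
// `params` and `n_seq` are assumed to be set up by the caller.
//
//   common_speculative_ptr spec(common_speculative_init(params, n_seq));
//   if (!spec) {
//       // init failed, e.g. no usable speculative implementation
//   }
//
//   // ... use spec.get() with the functions above ...
//
//   // common_speculative_free() runs automatically when `spec` leaves scope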