fix batch size

This commit is contained in:
Aman Gupta
2026-05-11 12:22:37 +08:00
parent a428b010ab
commit c417ddfc74

View File

@@ -413,11 +413,11 @@ struct common_speculative_state_mtp : public common_speculative_impl {
n_embd = llama_model_n_embd(llama_get_model(ctx_dft));
const int32_t n_ub = (int32_t) llama_n_ubatch(ctx_dft);
batch = llama_batch_init(/*n_tokens=*/ n_ub, /*embd=*/ n_embd, /*n_seq_max=*/ 1);
const int32_t n_b = (int32_t) llama_n_batch(ctx_dft);
batch = llama_batch_init(/*n_tokens=*/ n_b, /*embd=*/ n_embd, /*n_seq_max=*/ 1);
// llama_batch_init allocates only one of token/embd; MTP needs both.
// TODO: fix, how to call without malloc
batch.token = (llama_token *) malloc(sizeof(llama_token) * n_ub);
batch.token = (llama_token *) malloc(sizeof(llama_token) * n_b);
smpls.resize(n_seq);
for (auto & s : smpls) {