allow to toggle embedding mode

author Douglas Hanley
date   2024-03-07 11:55:27 -06:00
parent f618e5060a
commit bd3d9fbfed
5 changed files with 21 additions and 12 deletions
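
This change drops the separate causal_attn context parameter in favor of a runtime toggle: the new llama_set_embeddings() lets a single context switch between generation and embedding extraction without being recreated. A minimal usage sketch (not part of the diff), assuming a model, a context, and a tokenized prompt `tokens` already exist, and that the surrounding calls (llama_decode, llama_batch_get_one, llama_get_embeddings_seq, llama_n_embd) follow the llama.cpp C API of this period:

    #include "llama.h"

    #include <cstdio>
    #include <vector>

    // Sketch: drive one llama_context in generation mode, then in embedding mode.
    // `model`, `ctx` and `tokens` are assumed to have been set up elsewhere;
    // kv-cache management between the two calls is omitted for brevity.
    static void run_both_modes(llama_model * model, llama_context * ctx,
                               std::vector<llama_token> & tokens) {
        llama_batch batch = llama_batch_get_one(tokens.data(), (int32_t) tokens.size(), 0, 0);

        // 1) generation: embeddings off, attention stays causal
        llama_set_embeddings(ctx, false);
        if (llama_decode(ctx, batch) != 0) {
            fprintf(stderr, "decode failed\n");
            return;
        }
        // ... sample the next token from llama_get_logits(ctx) ...

        // 2) embeddings: same context, attention becomes non-causal
        llama_set_embeddings(ctx, true);
        if (llama_decode(ctx, batch) != 0) {
            fprintf(stderr, "decode failed\n");
            return;
        }
        const float * emb = llama_get_embeddings_seq(ctx, 0); // pooled embedding of sequence 0 (assumed API)
        if (emb != nullptr) {
            printf("embedding[0] = %f (n_embd = %d)\n", emb[0], llama_n_embd(model));
        }
    }
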


@@ -1684,7 +1684,6 @@ struct llama_cparams {
bool embeddings;
bool offload_kqv;
bool causal_attn;
enum llama_pooling_type pooling_type;
ggml_backend_sched_eval_callback cb_eval;
@@ -8030,7 +8029,14 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
ggml_backend_tensor_set(lctx.inp_pos, batch.pos, 0, n_tokens*ggml_element_size(lctx.inp_pos));
}
if (cparams.causal_attn) {
GGML_ASSERT(
(hparams.causal_attn || cparams.embeddings) &&
"non-causal attention with generative models is not supported"
);
// NOTE: hparams.causal_attn indicates the model is capable of generation and uses the kv cache.
// But if cparams.embeddings is set, the attention will be non-causal nonetheless.
if (!cparams.embeddings) {
const int64_t n_kv = kv_self.n;
const int64_t n_tokens = batch.n_tokens;
@@ -8055,7 +8061,7 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
}
}
} else {
// non-causal attention attends only the tokens within the batch (i.e. the KV cache is not used)
// with causal attention, the mask needs to match the kv cache size
const int64_t n_tokens = batch.n_tokens;
const int64_t n_stride = hparams.causal_attn ? kv_self.n : n_tokens;
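
In the non-causal branch above, a token attends only to tokens of the same sequence within the current batch, and each mask row is padded out to n_stride (the KV cache size when the model itself is causal-capable, otherwise just n_tokens). A standalone sketch of that masking rule, independent of the llama.cpp internals (the helper name and types are illustrative):

    #include <cstdint>
    #include <limits>
    #include <vector>

    // Illustrative only: a non-causal attention mask for one batch.
    // mask[i*n_stride + j] == 0.0f means token i may attend to token j;
    // -infinity means the pair is masked out.
    static std::vector<float> make_noncausal_mask(const std::vector<int32_t> & seq_id,
                                                  int64_t n_stride) {
        const int64_t n_tokens = (int64_t) seq_id.size();
        const float   neg_inf  = -std::numeric_limits<float>::infinity();

        std::vector<float> mask((size_t)(n_tokens*n_stride), neg_inf);
        for (int64_t i = 0; i < n_tokens; ++i) {
            for (int64_t j = 0; j < n_tokens; ++j) {
                // no i >= j restriction here: only same-sequence visibility matters
                if (seq_id[i] == seq_id[j]) {
                    mask[i*n_stride + j] = 0.0f;
                }
            }
            // columns [n_tokens, n_stride) stay at -inf, padding the row to the kv size
        }
        return mask;
    }
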
@@ -11998,7 +12004,6 @@ struct llama_context_params llama_context_default_params() {
/*.logits_all =*/ false,
/*.embeddings =*/ false,
/*.offload_kqv =*/ true,
/*.causal_attn =*/ true,
/*.abort_callback =*/ nullptr,
/*.abort_callback_data =*/ nullptr,
};
@@ -12150,7 +12155,6 @@ struct llama_context * llama_new_context_with_model(
cparams.defrag_thold = params.defrag_thold;
cparams.embeddings = params.embeddings;
cparams.offload_kqv = params.offload_kqv;
cparams.causal_attn = params.causal_attn;
cparams.pooling_type = params.pooling_type;
cparams.n_ctx = params.n_ctx == 0 ? hparams.n_ctx_train : params.n_ctx;
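
With causal_attn gone from llama_context_params, embedding behavior is chosen either at context creation via params.embeddings or later through llama_set_embeddings(), added in the hunk below. A short sketch of both paths; the model path, pooling choice, and loading/teardown calls are placeholders around this commit's API, not prescribed by it:

    #include "llama.h"

    int main() {
        // llama_backend_init()/llama_backend_free() and most error handling trimmed for brevity
        llama_model_params model_params = llama_model_default_params();
        llama_model * model = llama_load_model_from_file("model.gguf", model_params); // placeholder path
        if (model == nullptr) {
            return 1;
        }

        // option 1: create the context directly in embedding mode
        llama_context_params ctx_params = llama_context_default_params();
        ctx_params.embeddings   = true;
        ctx_params.pooling_type = LLAMA_POOLING_TYPE_MEAN; // assumed pooling choice
        llama_context * ctx = llama_new_context_with_model(model, ctx_params);

        // option 2: flip the same context back to generation later on
        llama_set_embeddings(ctx, false);

        llama_free(ctx);
        llama_free_model(model);
        return 0;
    }
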
@@ -13165,6 +13169,10 @@ void llama_set_abort_callback(struct llama_context * ctx, bool (*abort_callback)
ctx->abort_callback_data = abort_callback_data;
}
void llama_set_embeddings(struct llama_context * ctx, bool embeddings) {
ctx->cparams.embeddings = embeddings;
}
struct llama_batch llama_batch_get_one(
llama_token * tokens,
int32_t n_tokens,