allow to toggle embedding mode
llama.cpp: 18 changed lines
@@ -1684,7 +1684,6 @@ struct llama_cparams {
     bool embeddings;
     bool offload_kqv;
     bool causal_attn;
     enum llama_pooling_type pooling_type;

     ggml_backend_sched_eval_callback cb_eval;
@@ -8030,7 +8029,14 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
         ggml_backend_tensor_set(lctx.inp_pos, batch.pos, 0, n_tokens*ggml_element_size(lctx.inp_pos));
     }

     if (cparams.causal_attn) {
         GGML_ASSERT(
             (hparams.causal_attn || cparams.embeddings) &&
             "non-causal attention with generative models is not supported"
         );

         // NOTE: hparams.causal_attn indicates the model is capable of generation and uses the kv cache.
         // But if cparams.embeddings is set, the attention will be non-causal nonetheless.
         if (!cparams.embeddings) {
             const int64_t n_kv     = kv_self.n;
             const int64_t n_tokens = batch.n_tokens;

@@ -8055,7 +8061,7 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
             }
         }
     } else {
         // non-causal attention attends only the tokens within the batch (i.e. the KV cache is not used)
         // with causal attention, the mask needs to match the kv cache size
         const int64_t n_tokens = batch.n_tokens;
         const int64_t n_stride = hparams.causal_attn ? kv_self.n : n_tokens;
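The branch above decides how wide each attention-mask row is. A minimal sketch of that sizing rule, for illustration only (the helper function below is not llama.cpp code; only the ternary mirrors the diff's n_stride line):

#include <stdbool.h>
#include <stdint.h>

// Illustrative helper, not part of llama.cpp: with causal attention a mask
// row must span the whole KV cache (new tokens attend to everything already
// cached), while non-causal (embedding) attention attends only the current
// batch, so n_tokens columns suffice. Mirrors:
//     n_stride = hparams.causal_attn ? kv_self.n : n_tokens;
static inline int64_t mask_row_stride(bool causal_attn, int64_t n_kv, int64_t n_tokens) {
    return causal_attn ? n_kv : n_tokens;
}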
@@ -11998,7 +12004,6 @@ struct llama_context_params llama_context_default_params() {
         /*.logits_all          =*/ false,
         /*.embeddings          =*/ false,
         /*.offload_kqv         =*/ true,
         /*.causal_attn         =*/ true,
         /*.abort_callback      =*/ nullptr,
         /*.abort_callback_data =*/ nullptr,
     };
@@ -12150,7 +12155,6 @@ struct llama_context * llama_new_context_with_model(
     cparams.defrag_thold = params.defrag_thold;
     cparams.embeddings   = params.embeddings;
     cparams.offload_kqv  = params.offload_kqv;
     cparams.causal_attn  = params.causal_attn;
     cparams.pooling_type = params.pooling_type;

     cparams.n_ctx        = params.n_ctx == 0 ? hparams.n_ctx_train : params.n_ctx;
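For reference, a context can still be created directly in embedding mode through the public params, with the new setter covering changes after creation. A sketch assuming a model already loaded into `model` (error handling omitted):

struct llama_context_params ctx_params = llama_context_default_params();
ctx_params.embeddings = true;  // start the context in embedding mode

struct llama_context * ctx = llama_new_context_with_model(model, ctx_params);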
@@ -13165,6 +13169,10 @@ void llama_set_abort_callback(struct llama_context * ctx, bool (*abort_callback)
     ctx->abort_callback_data = abort_callback_data;
 }

+void llama_set_embeddings(struct llama_context * ctx, bool embeddings) {
+    ctx->cparams.embeddings = embeddings;
+}
+
 struct llama_batch llama_batch_get_one(
              llama_token * tokens,
                  int32_t   n_tokens,
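With the setter in place, a single context can alternate between generation and embedding extraction. A hedged usage sketch (batch setup, pooling configuration, and error handling omitted; llama_set_embeddings, llama_decode, and llama_get_embeddings_seq are the existing API, the surrounding flow is illustrative):

// Extract embeddings for a batch, then return to generation, on one context.
llama_set_embeddings(ctx, true);                       // decodes now yield embeddings
llama_decode(ctx, batch);                              // non-causal attention path
const float * emb = llama_get_embeddings_seq(ctx, 0);  // pooled embedding of sequence 0

llama_set_embeddings(ctx, false);                      // back to causal generation
llama_decode(ctx, batch);                              // logits are populated again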