mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2026-05-13 04:24:17 +00:00
cont : fix
This commit is contained in:
@@ -1269,7 +1269,7 @@ static void set_input_kq_mask_impl(const args_set_input_kq_mask & args, float *
|
||||
|
||||
// the min position in the batch for each sequence
|
||||
llama_pos seq_pos_min[LLAMA_MAX_SEQ];
|
||||
std::fill(seq_pos_min, seq_pos_min + LLAMA_MAX_SEQ, INT32_MIN);
|
||||
std::fill(seq_pos_min, seq_pos_min + LLAMA_MAX_SEQ, INT32_MAX);
|
||||
|
||||
for (uint32_t i = 0; i < ubatch->n_tokens; ++i) {
|
||||
const llama_seq_id seq_id = ubatch->seq_id[i][0];
|
||||
@@ -1309,7 +1309,7 @@ static void set_input_kq_mask_impl(const args_set_input_kq_mask & args, float *
|
||||
auto & idxs = seq_idxs[seq_id];
|
||||
|
||||
if (!alibi) {
|
||||
if (ii > 0 && seq_srct.find(seq_id) != seq_srct.end()) {
|
||||
if (seq_srct.find(seq_id) != seq_srct.end()) {
|
||||
const uint32_t srct = seq_srct[seq_id];
|
||||
|
||||
const uint64_t idst_prev = n_kv*srct;
|
||||
@@ -1337,10 +1337,10 @@ static void set_input_kq_mask_impl(const args_set_input_kq_mask & args, float *
|
||||
|
||||
j = idxs[jj];
|
||||
}
|
||||
}
|
||||
|
||||
if (cells.is_empty(j)) {
|
||||
goto skip;
|
||||
}
|
||||
if (cells.is_empty(j)) {
|
||||
goto skip;
|
||||
}
|
||||
|
||||
// mask the token if not the same sequence
|
||||
|
||||
Reference in New Issue
Block a user