mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2026-03-17 16:44:07 +00:00
cuda/hip: fix loop unrolling in ssm-conv (#20369)
This commit is contained in:
@@ -76,7 +76,7 @@ static __global__ void ssm_conv_long_token_f32(const float * __restrict__ src0,
|
||||
int row = tid / load_cols;
|
||||
int col = tid % load_cols;
|
||||
#pragma unroll
|
||||
for (int idx = tid; idx < total_elems; idx += split_d_inner) {
|
||||
for (int idx = 0; idx < total_elems; idx += split_d_inner) {
|
||||
if (row < (int)split_d_inner) {
|
||||
smem[row * n_cols + col] = x_block[row * stride_x + col];
|
||||
}
|
||||
@@ -84,6 +84,9 @@ static __global__ void ssm_conv_long_token_f32(const float * __restrict__ src0,
|
||||
col += split_d_inner;
|
||||
row += col / load_cols;
|
||||
col = col % load_cols;
|
||||
if (idx >= total_elems - tid - split_d_inner) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
|
||||
Reference in New Issue
Block a user