mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2026-05-15 05:24:06 +00:00
cuda : avoid warp_reduce for smax
This commit is contained in:
@@ -6621,7 +6621,6 @@ static __global__ void flash_attn_ext_f16(
|
||||
M[j] = __hmax(M[j], s);
|
||||
}
|
||||
|
||||
smax = warp_reduce_max(smax);
|
||||
M[j] = warp_reduce_max(M[j]);
|
||||
|
||||
const half ms = __hisinf(m) == -1 ? __float2half(0.0f) : hexp(m - M[j]);
|
||||
@@ -6649,6 +6648,8 @@ static __global__ void flash_attn_ext_f16(
|
||||
}
|
||||
}
|
||||
|
||||
smax = warp_reduce_max(smax);
|
||||
|
||||
// skip -INF blocks
|
||||
if (__hisinf(smax) == -1) {
|
||||
continue;
|
||||
|
||||
Reference in New Issue
Block a user