mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2026-03-17 16:44:07 +00:00
metal : fix l2 norm scale (#20493)
This commit is contained in:
@@ -1156,7 +1156,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
|
||||
case GGML_OP_RWKV_WKV7:
|
||||
return true;
|
||||
case GGML_OP_GATED_DELTA_NET:
|
||||
return op->src[2]->ne[0] % 32 == 0;
|
||||
return has_simdgroup_reduction && op->src[2]->ne[0] % 32 == 0;
|
||||
case GGML_OP_SOLVE_TRI:
|
||||
case GGML_OP_MUL_MAT:
|
||||
case GGML_OP_MUL_MAT_ID:
|
||||
|
||||
@@ -3006,7 +3006,7 @@ kernel void kernel_l2_norm_impl(
|
||||
sumf = shmem_f32[tiisg];
|
||||
sumf = simd_sum(sumf);
|
||||
|
||||
const float scale = 1.0f/sqrt(max(sumf, args.eps));
|
||||
const float scale = 1.0f/max(sqrt(sumf), args.eps);
|
||||
|
||||
for (int i00 = tpitg.x; i00 < args.ne00; i00 += ntg.x) {
|
||||
y[i00] = x[i00] * scale;
|
||||
|
||||
Reference in New Issue
Block a user