Compare commits

...

5 Commits
b6848 ... b6853

Author SHA1 Message Date
Johannes Gäßler
945501f5ea llama: fix leaked buffers for mmap + split files (#16765) 2025-10-27 09:17:31 +01:00
Aman Gupta
75cbdd3fce test-backend-ops: print failed tests at the end (#16785) 2025-10-27 09:25:10 +08:00
tamarPal
2b9bd9bf4e sycl: add ROLL operation support (#16665)
* sycl: add ROLL operation support

- Implement ggml_sycl_roll function for F32 tensors
- Add multi-axis roll operation with SYCL kernel
- Support all 4 tensor dimensions with proper shift normalization
- Add roll.cpp and roll.hpp to SYCL backend
- Update backend dispatch and supports_op for GGML_OP_ROLL
- Tests: 17662/17662 pass with identical CPU reference results

* fix: remove trailing whitespace from roll.cpp

- Fix EditorConfig violations in ggml/src/ggml-sycl/roll.cpp
- Remove trailing spaces from lines 6, 11, 28, 47, 58, 60

* ci: retrigger

* sycl: remove wait() calls from ROLL operation

* fix: editorconfig — LF endings + final newline for roll.hpp

---------

Co-authored-by: tamarPal <tamarPal@example.com>
2025-10-27 09:20:24 +08:00
shani-f
59fc1ec8e8 sycl: add REPEAT_BACK operation support (#16734)
* SYCL repeat_back v1 — add core op + switch case

* Implement repeat_back SYCL operation and minor fixes

* Update ggml/src/ggml-sycl/repeat_back.cpp

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

* Update ggml/src/ggml-sycl/repeat_back.hpp

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

* Update ggml/src/ggml-sycl/ggml-sycl.cpp

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

---------

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>
2025-10-27 09:19:50 +08:00
Aman Gupta
75d33b9302 CUDA: support for weight clamp in top-k norm (#16702) 2025-10-27 09:06:16 +08:00
11 changed files with 344 additions and 52 deletions

View File

@@ -2976,7 +2976,7 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
if (ops.size() == topk_moe_ops_with_norm.size() &&
ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 3, node_idx + 8 })) {
ggml_tensor * softmax = cgraph->nodes[node_idx];
ggml_tensor * weights = cgraph->nodes[node_idx+8];
ggml_tensor * weights = cgraph->nodes[node_idx + 9];
if (ggml_cuda_should_use_topk_moe(softmax, weights)) {
return true;
@@ -2986,7 +2986,7 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
if (ops.size() == topk_moe_ops.size() &&
ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 3, node_idx + 4 })) {
ggml_tensor * softmax = cgraph->nodes[node_idx];
ggml_tensor * weights = cgraph->nodes[node_idx+4];
ggml_tensor * weights = cgraph->nodes[node_idx + 4];
if (ggml_cuda_should_use_topk_moe(softmax, weights)) {
return true;
}
@@ -3125,17 +3125,18 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
if (!disable_fusion) {
if (ggml_cuda_can_fuse(cgraph, i, ggml_cuda_topk_moe_ops(/*with norm*/ true), {})) {
ggml_tensor * weights = cgraph->nodes[i+8];
ggml_tensor * selected_experts = cgraph->nodes[i+3];
ggml_tensor * weights = cgraph->nodes[i + 9];
ggml_tensor * selected_experts = cgraph->nodes[i + 3];
ggml_tensor * clamp = cgraph->nodes[i + 7];
ggml_cuda_op_topk_moe(*cuda_ctx, node->src[0], weights, selected_experts, /*with norm*/ true,
/*delayed softmax*/ false);
i += 8;
/*delayed softmax*/ false, clamp);
i += 9;
continue;
}
if (ggml_cuda_can_fuse(cgraph, i, ggml_cuda_topk_moe_ops(/*with norm*/ false), {})) {
ggml_tensor * weights = cgraph->nodes[i+4];
ggml_tensor * selected_experts = cgraph->nodes[i+3];
ggml_tensor * weights = cgraph->nodes[i + 4];
ggml_tensor * selected_experts = cgraph->nodes[i + 3];
ggml_cuda_op_topk_moe(*cuda_ctx, node->src[0], weights, selected_experts, /*with norm*/ false,
/*delayed softmax*/ false);
i += 4;

View File

@@ -2,6 +2,7 @@
#include "ggml.h"
#include "topk-moe.cuh"
#include <cmath>
#include <initializer_list>
// Warp-local softmax used for both the pre-top-k logits and the post-top-k delayed path.
@@ -63,7 +64,8 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float *
float * weights,
int32_t * ids,
const int n_rows,
const int n_expert_used) {
const int n_expert_used,
const float clamp_val) {
const int row = blockIdx.x * blockDim.y + threadIdx.y;
if (row >= n_rows) {
return;
@@ -139,6 +141,7 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float *
if constexpr (with_norm) {
wt_sum = warp_reduce_sum(wt_sum);
wt_sum = max(wt_sum, clamp_val);
const float inv_sum = 1.0f / wt_sum;
for (int i = 0; i < experts_per_thread; i++) {
@@ -157,6 +160,10 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float *
weights[idx] = output_weights[i];
}
}
if (!with_norm) {
GGML_UNUSED(clamp_val);
}
}
template <bool with_norm, bool delayed_softmax = false>
@@ -166,9 +173,9 @@ static void launch_topk_moe_cuda(ggml_backend_cuda_context & ctx,
int32_t * ids,
const int n_rows,
const int n_expert,
const int n_expert_used) {
const int n_expert_used,
const float clamp_val) {
static_assert(!(with_norm && delayed_softmax), "delayed softmax is not supported with weight normalization");
const int rows_per_block = 4;
dim3 grid_dims((n_rows + rows_per_block - 1) / rows_per_block, 1, 1);
dim3 block_dims(WARP_SIZE, rows_per_block, 1);
@@ -177,43 +184,43 @@ static void launch_topk_moe_cuda(ggml_backend_cuda_context & ctx,
switch (n_expert) {
case 1:
topk_moe_cuda<1, with_norm, delayed_softmax>
<<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
<<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used, clamp_val);
break;
case 2:
topk_moe_cuda<2, with_norm, delayed_softmax>
<<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
<<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used, clamp_val);
break;
case 4:
topk_moe_cuda<4, with_norm, delayed_softmax>
<<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
<<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used, clamp_val);
break;
case 8:
topk_moe_cuda<8, with_norm, delayed_softmax>
<<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
<<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used, clamp_val);
break;
case 16:
topk_moe_cuda<16, with_norm, delayed_softmax>
<<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
<<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used, clamp_val);
break;
case 32:
topk_moe_cuda<32, with_norm, delayed_softmax>
<<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
<<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used, clamp_val);
break;
case 64:
topk_moe_cuda<64, with_norm, delayed_softmax>
<<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
<<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used, clamp_val);
break;
case 128:
topk_moe_cuda<128, with_norm, delayed_softmax>
<<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
<<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used, clamp_val);
break;
case 256:
topk_moe_cuda<256, with_norm, delayed_softmax>
<<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
<<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used, clamp_val);
break;
case 512:
topk_moe_cuda<512, with_norm, delayed_softmax>
<<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
<<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used, clamp_val);
break;
default:
GGML_ASSERT(false && "fatal error");
@@ -226,7 +233,8 @@ void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx,
ggml_tensor * weights,
ggml_tensor * ids,
const bool with_norm,
const bool delayed_softmax) {
const bool delayed_softmax,
ggml_tensor * clamp) {
GGML_ASSERT(logits->type == GGML_TYPE_F32);
GGML_ASSERT(weights->type == GGML_TYPE_F32);
GGML_ASSERT(ids->type == GGML_TYPE_I32);
@@ -242,18 +250,25 @@ void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx,
const int n_expert_used = weights->ne[1];
float clamp_val = -INFINITY;
if (with_norm) {
launch_topk_moe_cuda<true>(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used);
if (clamp) {
clamp_val = ggml_get_op_params_f32(clamp, 0);
}
launch_topk_moe_cuda<true>(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used, clamp_val);
} else {
GGML_ASSERT(clamp == nullptr);
if (delayed_softmax) {
launch_topk_moe_cuda<false, true>(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used);
launch_topk_moe_cuda<false, true>(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used,
clamp_val);
} else {
launch_topk_moe_cuda<false, false>(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used);
launch_topk_moe_cuda<false, false>(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used,
clamp_val);
}
}
}
bool ggml_cuda_should_use_topk_moe(const ggml_tensor * softmax, const ggml_tensor * weights) {
bool ggml_cuda_should_use_topk_moe(const ggml_tensor * softmax, const ggml_tensor * weights, const ggml_tensor * clamp) {
float scale = 1.0f;
float max_bias = 0.0f;
@@ -279,13 +294,26 @@ bool ggml_cuda_should_use_topk_moe(const ggml_tensor * softmax, const ggml_tenso
return false;
}
if (clamp) {
if (clamp->op != GGML_OP_CLAMP) {
return false;
}
float max_val = ggml_get_op_params_f32(clamp, 1);
if (max_val != INFINITY) {
return false;
}
}
return true;
}
std::initializer_list<enum ggml_op> ggml_cuda_topk_moe_ops(bool norm, bool delayed_softmax) {
static std::initializer_list<enum ggml_op> norm_ops = { GGML_OP_SOFT_MAX, GGML_OP_RESHAPE, GGML_OP_ARGSORT,
GGML_OP_VIEW, GGML_OP_GET_ROWS, GGML_OP_RESHAPE,
GGML_OP_SUM_ROWS, GGML_OP_DIV, GGML_OP_RESHAPE };
GGML_OP_SUM_ROWS, GGML_OP_CLAMP, GGML_OP_DIV,
GGML_OP_RESHAPE };
static std::initializer_list<enum ggml_op> no_norm_ops = { GGML_OP_SOFT_MAX, GGML_OP_RESHAPE, GGML_OP_ARGSORT,
GGML_OP_VIEW, GGML_OP_GET_ROWS };

View File

@@ -8,8 +8,9 @@ void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx,
ggml_tensor * weights,
ggml_tensor * ids,
const bool with_norm,
const bool delayed_softmax = false);
const bool delayed_softmax = false,
ggml_tensor * weight_clamp = nullptr);
bool ggml_cuda_should_use_topk_moe(const ggml_tensor * softmax, const ggml_tensor * weights);
bool ggml_cuda_should_use_topk_moe(const ggml_tensor * softmax, const ggml_tensor * weights, const ggml_tensor * clamp = nullptr);
std::initializer_list<enum ggml_op> ggml_cuda_topk_moe_ops(bool with_norm, bool delayed_softmax = false);

View File

@@ -32,6 +32,7 @@
#include "pad.hpp"
#include "quantize.hpp"
#include "quants.hpp"
#include "roll.hpp"
#include "rope.hpp"
#include "set_rows.hpp"
#include "softmax.hpp"

View File

@@ -48,6 +48,7 @@
#include "ggml-sycl/set.hpp"
#include "ggml-sycl/sycl_hw.hpp"
#include "ggml-sycl/getrows.hpp"
#include "ggml-sycl/repeat_back.hpp"
#include "ggml-sycl/quantize.hpp"
#include "ggml.h"
@@ -2615,6 +2616,10 @@ catch (sycl::exception const &exc) {
std::exit(1);
}
static void ggml_sycl_repeat_back(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
ggml_sycl_op_repeat_back(ctx, dst);
}
static void ggml_sycl_get_rows(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/2);
@@ -3679,6 +3684,9 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg
case GGML_OP_REPEAT:
ggml_sycl_repeat(ctx, dst);
break;
case GGML_OP_REPEAT_BACK:
ggml_sycl_repeat_back(ctx, dst);
break;
case GGML_OP_GET_ROWS:
ggml_sycl_get_rows(ctx, dst);
break;
@@ -3913,6 +3921,9 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg
case GGML_OP_GATED_LINEAR_ATTN:
ggml_sycl_op_gated_linear_attn(ctx, dst);
break;
case GGML_OP_ROLL:
ggml_sycl_roll(ctx, dst);
break;
case GGML_OP_ARANGE:
ggml_sycl_arange(ctx, dst);
break;
@@ -4516,6 +4527,11 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
ggml_type src0_type = op->src[0]->type;
return src0_type != GGML_TYPE_I32 && src0_type != GGML_TYPE_I16;
}
case GGML_OP_REPEAT_BACK:
{
ggml_type src0_type = op->src[0]->type;
return src0_type == GGML_TYPE_F32;
}
case GGML_OP_DUP:
case GGML_OP_ARGMAX:
case GGML_OP_NONE:
@@ -4586,6 +4602,8 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
case GGML_OP_RWKV_WKV7:
case GGML_OP_GATED_LINEAR_ATTN:
return true;
case GGML_OP_ROLL:
return op->type == GGML_TYPE_F32;
case GGML_OP_ARANGE:
return op->type == GGML_TYPE_F32;
default:

View File

@@ -0,0 +1,56 @@
#include "repeat_back.hpp"
#include "common.hpp"
void ggml_sycl_op_repeat_back(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
GGML_ASSERT(dst->type == GGML_TYPE_F32);
const float * src0_dd = (const float *) dst->src[0]->data;
float * dst_dd = (float *) dst->data;
const int64_t ne0 = dst->ne[0], ne1 = dst->ne[1], ne2 = dst->ne[2], ne3 = dst->ne[3];
const int64_t ne00 = dst->src[0]->ne[0], ne01 = dst->src[0]->ne[1], ne02 = dst->src[0]->ne[2],
ne03 = dst->src[0]->ne[3];
const int nr0 = (int) (ne00 / ne0);
const int nr1 = (int) (ne01 / ne1);
const int nr2 = (int) (ne02 / ne2);
const int nr3 = (int) (ne03 / ne3);
const size_t total = ne0 * ne1 * ne2 * ne3;
const int BLOCK_SIZE = 256;
const int num_blocks = (total + BLOCK_SIZE - 1) / BLOCK_SIZE;
queue_ptr stream = ctx.stream();
stream->parallel_for(
sycl::nd_range<1>(sycl::range<1>(num_blocks * BLOCK_SIZE), sycl::range<1>(BLOCK_SIZE)),
[=](sycl::nd_item<1> item_ct1) {
const size_t i = item_ct1.get_global_linear_id();
if (i >= total) {
return;
}
const int i0 = i % ne0;
const int i1 = (i / ne0) % ne1;
const int i2 = (i / (ne0 * ne1)) % ne2;
const int i3 = i / (ne0 * ne1 * ne2);
float acc = 0.0f;
for (int j3 = 0; j3 < nr3; ++j3) {
for (int j2 = 0; j2 < nr2; ++j2) {
for (int j1 = 0; j1 < nr1; ++j1) {
for (int j0 = 0; j0 < nr0; ++j0) {
acc += src0_dd[(i0 + j0 * ne0) + (i1 + j1 * ne1) * ne00 + (i2 + j2 * ne2) * ne00 * ne01 +
(i3 + j3 * ne3) * ne00 * ne01 * ne02];
}
}
}
}
dst_dd[i] = acc;
});
}

View File

@@ -0,0 +1,8 @@
#ifndef GGML_SYCL_REPEAT_BACK_HPP
#define GGML_SYCL_REPEAT_BACK_HPP
#include "common.hpp"
void ggml_sycl_op_repeat_back(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
#endif // GGML_SYCL_REPEAT_BACK_HPP

122
ggml/src/ggml-sycl/roll.cpp Normal file
View File

@@ -0,0 +1,122 @@
#include "roll.hpp"
#include "common.hpp"
using namespace sycl;
static inline int wrap_add(int i, int shift, int n) {
int s = i + shift;
return (s >= n) ? (s - n) : s;
}
static void kernel_roll_fused_i0_i1(
queue &q,
const float *src_d,
float *dst_d,
int ne0, int ne1, int ne2, int ne3,
int sh0, int sh1, int sh2, int sh3)
{
if (ne0 == 0 || ne1 == 0 || ne2 == 0 || ne3 == 0) return;
const int stride1 = ne0;
const int stride2 = ne0 * ne1;
const int stride3 = ne0 * ne1 * ne2;
const int shNe0 = (ne0 - sh0) % ne0;
const int shNe1 = (ne1 - sh1) % ne1;
const int shNe2 = (ne2 - sh2) % ne2;
const int shNe3 = (ne3 - sh3) % ne3;
const size_t g0 = (size_t) ne3;
const size_t g1 = (size_t) ne2;
const size_t g2 = (size_t) (ne1 * ne0);
const range<3> global{ g0, g1, g2 };
q.submit([&](handler &h) {
h.parallel_for(global, [=](id<3> idx) {
const int i3 = (int) idx[0];
const int i2 = (int) idx[1];
const int fused = (int) idx[2];
const int i1 = fused / ne0;
const int i0 = fused - i1 * ne0; // fused % ne0
const int idx_dst = i0
+ i1 * stride1
+ i2 * stride2
+ i3 * stride3;
const int s0 = wrap_add(i0, shNe0, ne0);
const int s1 = wrap_add(i1, shNe1, ne1);
const int s2 = wrap_add(i2, shNe2, ne2);
const int s3 = wrap_add(i3, shNe3, ne3);
const int idx_src = s0
+ s1 * stride1
+ s2 * stride2
+ s3 * stride3;
dst_d[idx_dst] = src_d[idx_src];
});
});
}
void ggml_sycl_roll(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
GGML_ASSERT(dst->type == GGML_TYPE_F32);
const ggml_tensor *src = dst->src[0];
GGML_ASSERT(src && src->type == GGML_TYPE_F32);
const int ne0 = (int) dst->ne[0];
const int ne1 = (int) dst->ne[1];
const int ne2 = (int) dst->ne[2];
const int ne3 = (int) dst->ne[3];
const int32_t *params = (const int32_t *) dst->op_params;
int shift0 = params[0];
int shift1 = params[1];
int shift2 = params[2];
int shift3 = params[3];
if ((shift0 | shift1 | shift2 | shift3) == 0) {
const size_t nb = ggml_nbytes(src);
queue *q = ctx.stream();
SYCL_CHECK(CHECK_TRY_ERROR(q->memcpy(dst->data, src->data, nb)));
return;
}
auto norm = [](int sh, int n) -> int {
if (n <= 0) return 0;
sh %= n;
if (sh < 0) sh += n;
return sh;
};
shift0 = norm(shift0, ne0);
shift1 = norm(shift1, ne1);
shift2 = norm(shift2, ne2);
shift3 = norm(shift3, ne3);
try {
queue *q = ctx.stream();
const float *src_d = (const float *) src->data;
float *dst_d = (float *) dst->data;
GGML_ASSERT(src_d && dst_d);
kernel_roll_fused_i0_i1(
*q, src_d, dst_d,
ne0, ne1, ne2, ne3,
shift0, shift1, shift2, shift3
);
} catch (const std::exception &e) {
std::fprintf(stderr, "[SYCL-ROLL] ERROR: %s\n", e.what());
throw;
}
}

View File

@@ -0,0 +1,20 @@
//
// MIT license
// Copyright (C) 2024 Intel Corporation
// SPDX-License-Identifier: MIT
//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
#ifndef GGML_SYCL_ROLL_HPP
#define GGML_SYCL_ROLL_HPP
#include "common.hpp"
void ggml_sycl_roll(ggml_backend_sycl_context & ctx, ggml_tensor *dst);
#endif // GGML_SYCL_ROLL_HPP

View File

@@ -15,7 +15,6 @@
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cfloat>
#include <cstring>
#include <cmath>
@@ -438,7 +437,7 @@ struct llama_model::impl {
llama_mlocks mlock_mmaps;
// contexts where the model tensors metadata is stored as well ass the corresponding buffers:
std::vector<std::pair<ggml_context_ptr, ggml_backend_buffer_ptr>> ctxs_bufs;
std::vector<std::pair<ggml_context_ptr, std::vector<ggml_backend_buffer_ptr>>> ctxs_bufs;
buft_list_t cpu_buft_list;
std::map<ggml_backend_dev_t, buft_list_t> gpu_buft_list;
@@ -6186,7 +6185,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
bool buffer_from_host_ptr_supported = props.caps.buffer_from_host_ptr;
bool is_default_buft = buft == ggml_backend_dev_buffer_type(dev);
ggml_backend_buffer_t buf = nullptr;
std::vector<ggml_backend_buffer_ptr> bufs;
if (ml.use_mmap && use_mmap_buffer && buffer_from_host_ptr_supported && is_default_buft) {
for (uint32_t idx = 0; idx < ml.files.size(); idx++) {
// only the mmap region containing the tensors in the model is mapped to the backend buffer
@@ -6199,15 +6198,16 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
continue;
}
const size_t max_size = ggml_get_max_tensor_size(ctx);
buf = ggml_backend_dev_buffer_from_host_ptr(dev, (char *) addr + first, last - first, max_size);
ggml_backend_buffer_t buf = ggml_backend_dev_buffer_from_host_ptr(dev, (char *) addr + first, last - first, max_size);
if (buf == nullptr) {
throw std::runtime_error(format("unable to allocate %s buffer", ggml_backend_buft_name(buft)));
}
bufs.emplace_back(buf);
buf_map.emplace(idx, buf);
}
}
else {
buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft);
ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft);
if (buf == nullptr) {
throw std::runtime_error(format("unable to allocate %s buffer", ggml_backend_buft_name(buft)));
}
@@ -6217,11 +6217,12 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
mlock_buf->init (ggml_backend_buffer_get_base(buf));
mlock_buf->grow_to(ggml_backend_buffer_get_size(buf));
}
bufs.emplace_back(buf);
for (uint32_t idx = 0; idx < ml.files.size(); idx++) {
buf_map.emplace(idx, buf);
}
}
pimpl->ctxs_bufs.emplace_back(std::move(ctx_ptr), buf);
pimpl->ctxs_bufs.emplace_back(std::move(ctx_ptr), std::move(bufs));
for (auto & buf : buf_map) {
// indicate that this buffer contains weights
@@ -6247,8 +6248,11 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
}
// print memory requirements per buffer type
for (auto & [_, buf] : pimpl->ctxs_bufs) {
LLAMA_LOG_INFO("%s: %12s model buffer size = %8.2f MiB\n", __func__, ggml_backend_buffer_name(buf.get()), ggml_backend_buffer_get_size(buf.get()) / 1024.0 / 1024.0);
for (auto & [_, bufs] : pimpl->ctxs_bufs) {
for (auto & buf: bufs) {
LLAMA_LOG_INFO("%s: %12s model buffer size = %8.2f MiB\n",
__func__, ggml_backend_buffer_name(buf.get()), ggml_backend_buffer_get_size(buf.get()) / 1024.0 / 1024.0);
}
}
// populate tensors_by_name
@@ -6300,8 +6304,10 @@ size_t llama_model::n_devices() const {
std::map<ggml_backend_buffer_type_t, size_t> llama_model::memory_breakdown() const {
std::map<ggml_backend_buffer_type_t, size_t> ret;
for (const auto & [_, buf] : pimpl->ctxs_bufs) {
ret[ggml_backend_buffer_get_type(buf.get())] += ggml_backend_buffer_get_size(buf.get());
for (const auto & [_, bufs] : pimpl->ctxs_bufs) {
for (const auto & buf : bufs) {
ret[ggml_backend_buffer_get_type(buf.get())] += ggml_backend_buffer_get_size(buf.get());
}
}
return ret;
}

View File

@@ -511,7 +511,7 @@ struct test_result {
};
// Printer classes for different output formats
enum class test_status_t { NOT_SUPPORTED, OK, FAIL };
enum class test_status_t { NOT_SUPPORTED, OK, FAIL, SKIPPED };
struct test_operation_info {
std::string op_name;
@@ -687,6 +687,8 @@ struct printer {
virtual void print_backend_status(const backend_status_info & info) { (void) info; }
virtual void print_overall_summary(const overall_summary_info & info) { (void) info; }
virtual void print_failed_tests(const std::vector<std::string> & failed_tests) { (void) failed_tests; }
};
struct console_printer : public printer {
@@ -804,6 +806,17 @@ struct console_printer : public printer {
}
}
void print_failed_tests(const std::vector<std::string> & failed_tests) override {
if (failed_tests.empty()) {
return;
}
printf("\nFailing tests:\n");
for (const auto & test_name : failed_tests) {
printf(" %s\n", test_name.c_str());
}
}
private:
void print_test_console(const test_result & result) {
printf(" %s(%s): ", result.op_name.c_str(), result.op_params.c_str());
@@ -1056,6 +1069,8 @@ struct test_case {
std::vector<ggml_tensor *> sentinels;
std::string current_op_name;
void add_sentinel(ggml_context * ctx) {
if (mode == MODE_PERF || mode == MODE_GRAD || mode == MODE_SUPPORT) {
return;
@@ -1127,7 +1142,10 @@ struct test_case {
}
}
bool eval(ggml_backend_t backend1, ggml_backend_t backend2, const char * op_names_filter, printer * output_printer) {
test_status_t eval(ggml_backend_t backend1,
ggml_backend_t backend2,
const char * op_names_filter,
printer * output_printer) {
mode = MODE_TEST;
ggml_init_params params = {
@@ -1144,11 +1162,12 @@ struct test_case {
add_sentinel(ctx);
ggml_tensor * out = build_graph(ctx);
std::string current_op_name = op_desc(out);
current_op_name = op_desc(out);
if (!matches_filter(out, op_names_filter)) {
//printf(" %s: skipping\n", op_desc(out).c_str());
ggml_free(ctx);
return true;
return test_status_t::SKIPPED;
}
// check if the backends support the ops
@@ -1172,7 +1191,7 @@ struct test_case {
}
ggml_free(ctx);
return true;
return test_status_t::NOT_SUPPORTED;
}
// post-graph sentinel
@@ -1184,7 +1203,7 @@ struct test_case {
if (buf == NULL) {
printf("failed to allocate tensors [%s] ", ggml_backend_name(backend1));
ggml_free(ctx);
return false;
return test_status_t::FAIL;
}
// build graph
@@ -1289,7 +1308,7 @@ struct test_case {
output_printer->print_test_result(result);
}
return test_passed;
return test_passed ? test_status_t::OK : test_status_t::FAIL;
}
bool eval_perf(ggml_backend_t backend, const char * op_names_filter, printer * output_printer) {
@@ -1306,7 +1325,7 @@ struct test_case {
GGML_ASSERT(ctx);
ggml_tensor * out = build_graph(ctx.get());
std::string current_op_name = op_desc(out);
current_op_name = op_desc(out);
if (!matches_filter(out, op_names_filter)) {
//printf(" %s: skipping\n", op_desc(out).c_str());
return true;
@@ -1435,8 +1454,9 @@ struct test_case {
ggml_context_ptr ctx(ggml_init(params)); // smart ptr
GGML_ASSERT(ctx);
ggml_tensor * out = build_graph(ctx.get());
std::string current_op_name = op_desc(out);
ggml_tensor * out = build_graph(ctx.get());
current_op_name = op_desc(out);
if (!matches_filter(out, op_names_filter)) {
return true;
}
@@ -4712,6 +4732,7 @@ struct test_topk_moe: public test_case {
out = ggml_reshape_2d(ctx, out, n_expert_used, n_tokens);
ggml_tensor * weights_sum = ggml_sum_rows(ctx, out); // [1, n_tokens]
weights_sum = ggml_clamp(ctx, weights_sum, 6.103515625e-5, INFINITY);
out = ggml_div(ctx, out, weights_sum); // [n_expert_used, n_tokens]
out = ggml_reshape_3d(ctx, out, 1, n_expert_used, n_tokens);
}
@@ -7359,16 +7380,26 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
}
size_t n_ok = 0;
size_t tests_run = 0;
std::vector<std::string> failed_tests;
for (auto & test : test_cases) {
if (test->eval(backend, backend_cpu, op_names_filter, output_printer)) {
test_status_t status = test->eval(backend, backend_cpu, op_names_filter, output_printer);
if (status == test_status_t::SKIPPED || status == test_status_t::NOT_SUPPORTED) {
continue;
}
tests_run++;
if (status == test_status_t::OK) {
n_ok++;
} else if (status == test_status_t::FAIL) {
failed_tests.push_back(test->current_op_name + "(" + test->vars() + ")");
}
}
output_printer->print_summary(test_summary_info(n_ok, test_cases.size(), false));
output_printer->print_summary(test_summary_info(n_ok, tests_run, false));
output_printer->print_failed_tests(failed_tests);
ggml_backend_free(backend_cpu);
return n_ok == test_cases.size();
return n_ok == tests_run;
}
if (mode == MODE_GRAD) {