mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2026-05-17 22:44:07 +00:00
Compare commits
8 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
75cbdd3fce | ||
|
|
2b9bd9bf4e | ||
|
|
59fc1ec8e8 | ||
|
|
75d33b9302 | ||
|
|
3470a5c891 | ||
|
|
bd562fe4f7 | ||
|
|
bbac6a26b2 | ||
|
|
73a48c9790 |
@@ -742,6 +742,12 @@ class TextModel(ModelBase):
|
||||
if (n_experts_used := self.hparams.get("num_experts_per_tok")) is not None:
|
||||
self.gguf_writer.add_expert_used_count(n_experts_used)
|
||||
logger.info(f"gguf: experts used count = {n_experts_used}")
|
||||
if (n_expert_groups := self.hparams.get("n_group")) is not None:
|
||||
self.gguf_writer.add_expert_group_count(n_expert_groups)
|
||||
logger.info(f"gguf: expert groups count = {n_expert_groups}")
|
||||
if (n_group_used := self.hparams.get("topk_group")) is not None:
|
||||
self.gguf_writer.add_expert_group_used_count(n_group_used)
|
||||
logger.info(f"gguf: expert groups used count = {n_group_used}")
|
||||
|
||||
if (head_dim := self.hparams.get("head_dim")) is not None:
|
||||
self.gguf_writer.add_key_length(head_dim)
|
||||
@@ -8233,8 +8239,6 @@ class BailingMoeV2Model(TextModel):
|
||||
self.gguf_writer.add_expert_weights_scale(hparams["routed_scaling_factor"])
|
||||
self.gguf_writer.add_expert_count(hparams["num_experts"])
|
||||
self.gguf_writer.add_expert_shared_count(hparams["num_shared_experts"])
|
||||
self.gguf_writer.add_expert_group_count(hparams["n_group"])
|
||||
self.gguf_writer.add_expert_group_used_count(hparams["topk_group"])
|
||||
self.gguf_writer.add_expert_weights_norm(hparams["norm_topk_prob"])
|
||||
|
||||
if hparams["score_function"] == "sigmoid":
|
||||
|
||||
@@ -226,16 +226,23 @@ static struct buffer_address ggml_dyn_tallocr_alloc(struct ggml_dyn_tallocr * al
|
||||
}
|
||||
|
||||
if (best_fit_block == -1) {
|
||||
// no suitable block found, try the last block (this will grow a chunks size)
|
||||
// no suitable block found, try the last block (this may grow a chunks size)
|
||||
int64_t best_reuse = INT64_MIN;
|
||||
for (int c = 0; c < alloc->n_chunks; ++c) {
|
||||
struct tallocr_chunk * chunk = alloc->chunks[c];
|
||||
if (chunk->n_free_blocks > 0) {
|
||||
struct free_block * block = &chunk->free_blocks[chunk->n_free_blocks - 1];
|
||||
max_avail = MAX(max_avail, block->size);
|
||||
if (block->size >= size) {
|
||||
int64_t reuse_factor = chunk->max_size - block->offset - size;
|
||||
// reuse_factor < 0 : amount of extra memory that needs to be allocated
|
||||
// reuse_factor = 0 : allocated free space exactly matches tensor size
|
||||
// reuse_factor > 0 : superfluous memory that will remain unused
|
||||
bool better_reuse = best_reuse < 0 && reuse_factor > best_reuse;
|
||||
bool better_fit = reuse_factor >= 0 && reuse_factor < best_reuse;
|
||||
if (block->size >= size && (better_reuse || better_fit)) {
|
||||
best_fit_chunk = c;
|
||||
best_fit_block = chunk->n_free_blocks - 1;
|
||||
break;
|
||||
best_reuse = reuse_factor;
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -268,7 +275,7 @@ static struct buffer_address ggml_dyn_tallocr_alloc(struct ggml_dyn_tallocr * al
|
||||
#ifdef GGML_ALLOCATOR_DEBUG
|
||||
add_allocated_tensor(alloc, addr, tensor);
|
||||
size_t cur_max = addr.offset + size;
|
||||
if (cur_max > alloc->max_size[addr.chunk]) {
|
||||
if (cur_max > chunk->max_size) {
|
||||
// sort allocated_tensors by chunk/offset
|
||||
for (int i = 0; i < 1024; i++) {
|
||||
for (int j = i + 1; j < 1024; j++) {
|
||||
|
||||
@@ -112,6 +112,30 @@ static __global__ void cpy_q_f32(const char * cx, char * cdst, const int ne,
|
||||
cpy_blck(cx + x_offset, cdst + dst_offset);
|
||||
}
|
||||
|
||||
template<typename src_t, typename dst_t>
|
||||
static __global__ void cpy_flt_contiguous(const char * cx, char * cdst, const int64_t ne) {
|
||||
const int64_t i = blockDim.x*blockIdx.x + threadIdx.x;
|
||||
|
||||
if (i >= ne) {
|
||||
return;
|
||||
}
|
||||
|
||||
const src_t * x = (const src_t *) cx;
|
||||
dst_t * dst = (dst_t *) cdst;
|
||||
|
||||
dst[i] = ggml_cuda_cast<dst_t>(x[i]);
|
||||
}
|
||||
|
||||
template<typename src_t, typename dst_t>
|
||||
static void ggml_cpy_flt_contiguous_cuda(
|
||||
const char * cx, char * cdst, const int64_t ne,
|
||||
cudaStream_t stream) {
|
||||
|
||||
const int64_t num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
|
||||
cpy_flt_contiguous<src_t, dst_t><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
|
||||
(cx, cdst, ne);
|
||||
}
|
||||
|
||||
template<typename src_t, typename dst_t>
|
||||
static void ggml_cpy_flt_cuda(
|
||||
const char * cx, char * cdst, const int ne,
|
||||
@@ -285,7 +309,9 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
|
||||
char * src0_ddc = (char *) src0->data;
|
||||
char * src1_ddc = (char *) src1->data;
|
||||
|
||||
if (src0->type == src1->type && ggml_is_contiguous(src0) && ggml_is_contiguous(src1)) {
|
||||
const bool contiguous_srcs = ggml_is_contiguous(src0) && ggml_is_contiguous(src1);
|
||||
|
||||
if (src0->type == src1->type && contiguous_srcs) {
|
||||
GGML_ASSERT(ggml_nbytes(src0) == ggml_nbytes(src1));
|
||||
#if defined(GGML_USE_MUSA) && defined(GGML_MUSA_MUDNN_COPY)
|
||||
if (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16) {
|
||||
@@ -296,11 +322,19 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
|
||||
CUDA_CHECK(cudaMemcpyAsync(src1_ddc, src0_ddc, ggml_nbytes(src0), cudaMemcpyDeviceToDevice, main_stream));
|
||||
}
|
||||
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
|
||||
ggml_cpy_flt_cuda<float, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
ggml_cpy_flt_cuda<float, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_BF16) {
|
||||
ggml_cpy_flt_cuda<float, nv_bfloat16> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
if (contiguous_srcs) {
|
||||
ggml_cpy_flt_contiguous_cuda<float, nv_bfloat16> (src0_ddc, src1_ddc, ne, main_stream);
|
||||
} else {
|
||||
ggml_cpy_flt_cuda<float, nv_bfloat16> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
}
|
||||
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16) {
|
||||
ggml_cpy_flt_cuda<float, half> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
if (contiguous_srcs) {
|
||||
ggml_cpy_flt_contiguous_cuda<float, half> (src0_ddc, src1_ddc, ne, main_stream);
|
||||
} else {
|
||||
ggml_cpy_flt_cuda<float, half> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
}
|
||||
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q8_0) {
|
||||
ggml_cpy_f32_q8_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
} else if (src0->type == GGML_TYPE_Q8_0 && src1->type == GGML_TYPE_F32) {
|
||||
@@ -327,21 +361,45 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
|
||||
} else if (src0->type == GGML_TYPE_Q5_1 && src1->type == GGML_TYPE_F32) {
|
||||
ggml_cpy_q5_1_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {
|
||||
ggml_cpy_flt_cuda<half, half> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
ggml_cpy_flt_cuda<half, half> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_BF16) {
|
||||
ggml_cpy_flt_cuda<half, nv_bfloat16> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
if (contiguous_srcs) {
|
||||
ggml_cpy_flt_contiguous_cuda<half, nv_bfloat16> (src0_ddc, src1_ddc, ne, main_stream);
|
||||
} else {
|
||||
ggml_cpy_flt_cuda<half, nv_bfloat16> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
}
|
||||
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) {
|
||||
ggml_cpy_flt_cuda<half, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
if (contiguous_srcs) {
|
||||
ggml_cpy_flt_contiguous_cuda<half, float> (src0_ddc, src1_ddc, ne, main_stream);
|
||||
} else {
|
||||
ggml_cpy_flt_cuda<half, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
}
|
||||
} else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_BF16) {
|
||||
ggml_cpy_flt_cuda<nv_bfloat16, nv_bfloat16> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
} else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F16) {
|
||||
ggml_cpy_flt_cuda<nv_bfloat16, half> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
if (contiguous_srcs) {
|
||||
ggml_cpy_flt_contiguous_cuda<nv_bfloat16, half> (src0_ddc, src1_ddc, ne, main_stream);
|
||||
} else {
|
||||
ggml_cpy_flt_cuda<nv_bfloat16, half> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
}
|
||||
} else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F32) {
|
||||
ggml_cpy_flt_cuda<nv_bfloat16, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
if (contiguous_srcs) {
|
||||
ggml_cpy_flt_contiguous_cuda<nv_bfloat16, float> (src0_ddc, src1_ddc, ne, main_stream);
|
||||
} else {
|
||||
ggml_cpy_flt_cuda<nv_bfloat16, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
}
|
||||
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_I32) {
|
||||
ggml_cpy_flt_cuda<float, int32_t> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
if (contiguous_srcs) {
|
||||
ggml_cpy_flt_contiguous_cuda<float, int32_t> (src0_ddc, src1_ddc, ne, main_stream);
|
||||
} else {
|
||||
ggml_cpy_flt_cuda<float, int32_t> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
}
|
||||
} else if (src0->type == GGML_TYPE_I32 && src1->type == GGML_TYPE_F32) {
|
||||
ggml_cpy_flt_cuda<int32_t, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
if (contiguous_srcs) {
|
||||
ggml_cpy_flt_contiguous_cuda<int32_t, float> (src0_ddc, src1_ddc, ne, main_stream);
|
||||
} else {
|
||||
ggml_cpy_flt_cuda<int32_t, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
}
|
||||
} else {
|
||||
GGML_ABORT("%s: unsupported type combination (%s to %s)\n", __func__,
|
||||
ggml_type_name(src0->type), ggml_type_name(src1->type));
|
||||
|
||||
@@ -1957,8 +1957,15 @@ static void ggml_cuda_mul_mat_batched_cublas_impl(ggml_backend_cuda_context & ct
|
||||
|
||||
size_t src1_stride_size = sizeof(cuda_t);
|
||||
|
||||
dim3 block_dims(ne13, ne12);
|
||||
k_compute_batched_ptrs<<<1, block_dims, 0, main_stream>>>(
|
||||
const int threads_x = 16;
|
||||
const int threads_y = 16;
|
||||
dim3 block_dims(threads_x, threads_y);
|
||||
|
||||
dim3 grid_dims(
|
||||
(ne13 + threads_x - 1) / threads_x,
|
||||
(ne12 + threads_y - 1) / threads_y
|
||||
);
|
||||
k_compute_batched_ptrs<<<grid_dims, block_dims, 0, main_stream>>>(
|
||||
src0_ptr, src1_ptr, dst_t,
|
||||
ptrs_src.get(), ptrs_dst.get(),
|
||||
ne12, ne13,
|
||||
@@ -2969,7 +2976,7 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
|
||||
if (ops.size() == topk_moe_ops_with_norm.size() &&
|
||||
ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 3, node_idx + 8 })) {
|
||||
ggml_tensor * softmax = cgraph->nodes[node_idx];
|
||||
ggml_tensor * weights = cgraph->nodes[node_idx+8];
|
||||
ggml_tensor * weights = cgraph->nodes[node_idx + 9];
|
||||
|
||||
if (ggml_cuda_should_use_topk_moe(softmax, weights)) {
|
||||
return true;
|
||||
@@ -2979,7 +2986,7 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
|
||||
if (ops.size() == topk_moe_ops.size() &&
|
||||
ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 3, node_idx + 4 })) {
|
||||
ggml_tensor * softmax = cgraph->nodes[node_idx];
|
||||
ggml_tensor * weights = cgraph->nodes[node_idx+4];
|
||||
ggml_tensor * weights = cgraph->nodes[node_idx + 4];
|
||||
if (ggml_cuda_should_use_topk_moe(softmax, weights)) {
|
||||
return true;
|
||||
}
|
||||
@@ -3118,17 +3125,18 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
|
||||
if (!disable_fusion) {
|
||||
|
||||
if (ggml_cuda_can_fuse(cgraph, i, ggml_cuda_topk_moe_ops(/*with norm*/ true), {})) {
|
||||
ggml_tensor * weights = cgraph->nodes[i+8];
|
||||
ggml_tensor * selected_experts = cgraph->nodes[i+3];
|
||||
ggml_tensor * weights = cgraph->nodes[i + 9];
|
||||
ggml_tensor * selected_experts = cgraph->nodes[i + 3];
|
||||
ggml_tensor * clamp = cgraph->nodes[i + 7];
|
||||
ggml_cuda_op_topk_moe(*cuda_ctx, node->src[0], weights, selected_experts, /*with norm*/ true,
|
||||
/*delayed softmax*/ false);
|
||||
i += 8;
|
||||
/*delayed softmax*/ false, clamp);
|
||||
i += 9;
|
||||
continue;
|
||||
}
|
||||
|
||||
if (ggml_cuda_can_fuse(cgraph, i, ggml_cuda_topk_moe_ops(/*with norm*/ false), {})) {
|
||||
ggml_tensor * weights = cgraph->nodes[i+4];
|
||||
ggml_tensor * selected_experts = cgraph->nodes[i+3];
|
||||
ggml_tensor * weights = cgraph->nodes[i + 4];
|
||||
ggml_tensor * selected_experts = cgraph->nodes[i + 3];
|
||||
ggml_cuda_op_topk_moe(*cuda_ctx, node->src[0], weights, selected_experts, /*with norm*/ false,
|
||||
/*delayed softmax*/ false);
|
||||
i += 4;
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
#include "ggml.h"
|
||||
#include "topk-moe.cuh"
|
||||
|
||||
#include <cmath>
|
||||
#include <initializer_list>
|
||||
|
||||
// Warp-local softmax used for both the pre-top-k logits and the post-top-k delayed path.
|
||||
@@ -63,7 +64,8 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float *
|
||||
float * weights,
|
||||
int32_t * ids,
|
||||
const int n_rows,
|
||||
const int n_expert_used) {
|
||||
const int n_expert_used,
|
||||
const float clamp_val) {
|
||||
const int row = blockIdx.x * blockDim.y + threadIdx.y;
|
||||
if (row >= n_rows) {
|
||||
return;
|
||||
@@ -139,6 +141,7 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float *
|
||||
|
||||
if constexpr (with_norm) {
|
||||
wt_sum = warp_reduce_sum(wt_sum);
|
||||
wt_sum = max(wt_sum, clamp_val);
|
||||
const float inv_sum = 1.0f / wt_sum;
|
||||
|
||||
for (int i = 0; i < experts_per_thread; i++) {
|
||||
@@ -157,6 +160,10 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float *
|
||||
weights[idx] = output_weights[i];
|
||||
}
|
||||
}
|
||||
|
||||
if (!with_norm) {
|
||||
GGML_UNUSED(clamp_val);
|
||||
}
|
||||
}
|
||||
|
||||
template <bool with_norm, bool delayed_softmax = false>
|
||||
@@ -166,9 +173,9 @@ static void launch_topk_moe_cuda(ggml_backend_cuda_context & ctx,
|
||||
int32_t * ids,
|
||||
const int n_rows,
|
||||
const int n_expert,
|
||||
const int n_expert_used) {
|
||||
const int n_expert_used,
|
||||
const float clamp_val) {
|
||||
static_assert(!(with_norm && delayed_softmax), "delayed softmax is not supported with weight normalization");
|
||||
|
||||
const int rows_per_block = 4;
|
||||
dim3 grid_dims((n_rows + rows_per_block - 1) / rows_per_block, 1, 1);
|
||||
dim3 block_dims(WARP_SIZE, rows_per_block, 1);
|
||||
@@ -177,43 +184,43 @@ static void launch_topk_moe_cuda(ggml_backend_cuda_context & ctx,
|
||||
switch (n_expert) {
|
||||
case 1:
|
||||
topk_moe_cuda<1, with_norm, delayed_softmax>
|
||||
<<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
|
||||
<<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used, clamp_val);
|
||||
break;
|
||||
case 2:
|
||||
topk_moe_cuda<2, with_norm, delayed_softmax>
|
||||
<<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
|
||||
<<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used, clamp_val);
|
||||
break;
|
||||
case 4:
|
||||
topk_moe_cuda<4, with_norm, delayed_softmax>
|
||||
<<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
|
||||
<<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used, clamp_val);
|
||||
break;
|
||||
case 8:
|
||||
topk_moe_cuda<8, with_norm, delayed_softmax>
|
||||
<<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
|
||||
<<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used, clamp_val);
|
||||
break;
|
||||
case 16:
|
||||
topk_moe_cuda<16, with_norm, delayed_softmax>
|
||||
<<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
|
||||
<<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used, clamp_val);
|
||||
break;
|
||||
case 32:
|
||||
topk_moe_cuda<32, with_norm, delayed_softmax>
|
||||
<<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
|
||||
<<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used, clamp_val);
|
||||
break;
|
||||
case 64:
|
||||
topk_moe_cuda<64, with_norm, delayed_softmax>
|
||||
<<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
|
||||
<<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used, clamp_val);
|
||||
break;
|
||||
case 128:
|
||||
topk_moe_cuda<128, with_norm, delayed_softmax>
|
||||
<<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
|
||||
<<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used, clamp_val);
|
||||
break;
|
||||
case 256:
|
||||
topk_moe_cuda<256, with_norm, delayed_softmax>
|
||||
<<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
|
||||
<<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used, clamp_val);
|
||||
break;
|
||||
case 512:
|
||||
topk_moe_cuda<512, with_norm, delayed_softmax>
|
||||
<<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
|
||||
<<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used, clamp_val);
|
||||
break;
|
||||
default:
|
||||
GGML_ASSERT(false && "fatal error");
|
||||
@@ -226,7 +233,8 @@ void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx,
|
||||
ggml_tensor * weights,
|
||||
ggml_tensor * ids,
|
||||
const bool with_norm,
|
||||
const bool delayed_softmax) {
|
||||
const bool delayed_softmax,
|
||||
ggml_tensor * clamp) {
|
||||
GGML_ASSERT(logits->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(weights->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(ids->type == GGML_TYPE_I32);
|
||||
@@ -242,18 +250,25 @@ void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx,
|
||||
|
||||
const int n_expert_used = weights->ne[1];
|
||||
|
||||
float clamp_val = -INFINITY;
|
||||
if (with_norm) {
|
||||
launch_topk_moe_cuda<true>(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used);
|
||||
if (clamp) {
|
||||
clamp_val = ggml_get_op_params_f32(clamp, 0);
|
||||
}
|
||||
launch_topk_moe_cuda<true>(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used, clamp_val);
|
||||
} else {
|
||||
GGML_ASSERT(clamp == nullptr);
|
||||
if (delayed_softmax) {
|
||||
launch_topk_moe_cuda<false, true>(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used);
|
||||
launch_topk_moe_cuda<false, true>(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used,
|
||||
clamp_val);
|
||||
} else {
|
||||
launch_topk_moe_cuda<false, false>(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used);
|
||||
launch_topk_moe_cuda<false, false>(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used,
|
||||
clamp_val);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
bool ggml_cuda_should_use_topk_moe(const ggml_tensor * softmax, const ggml_tensor * weights) {
|
||||
bool ggml_cuda_should_use_topk_moe(const ggml_tensor * softmax, const ggml_tensor * weights, const ggml_tensor * clamp) {
|
||||
float scale = 1.0f;
|
||||
float max_bias = 0.0f;
|
||||
|
||||
@@ -279,13 +294,26 @@ bool ggml_cuda_should_use_topk_moe(const ggml_tensor * softmax, const ggml_tenso
|
||||
return false;
|
||||
}
|
||||
|
||||
if (clamp) {
|
||||
if (clamp->op != GGML_OP_CLAMP) {
|
||||
return false;
|
||||
}
|
||||
float max_val = ggml_get_op_params_f32(clamp, 1);
|
||||
|
||||
if (max_val != INFINITY) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
std::initializer_list<enum ggml_op> ggml_cuda_topk_moe_ops(bool norm, bool delayed_softmax) {
|
||||
static std::initializer_list<enum ggml_op> norm_ops = { GGML_OP_SOFT_MAX, GGML_OP_RESHAPE, GGML_OP_ARGSORT,
|
||||
GGML_OP_VIEW, GGML_OP_GET_ROWS, GGML_OP_RESHAPE,
|
||||
GGML_OP_SUM_ROWS, GGML_OP_DIV, GGML_OP_RESHAPE };
|
||||
GGML_OP_SUM_ROWS, GGML_OP_CLAMP, GGML_OP_DIV,
|
||||
GGML_OP_RESHAPE };
|
||||
|
||||
static std::initializer_list<enum ggml_op> no_norm_ops = { GGML_OP_SOFT_MAX, GGML_OP_RESHAPE, GGML_OP_ARGSORT,
|
||||
GGML_OP_VIEW, GGML_OP_GET_ROWS };
|
||||
|
||||
@@ -8,8 +8,9 @@ void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx,
|
||||
ggml_tensor * weights,
|
||||
ggml_tensor * ids,
|
||||
const bool with_norm,
|
||||
const bool delayed_softmax = false);
|
||||
const bool delayed_softmax = false,
|
||||
ggml_tensor * weight_clamp = nullptr);
|
||||
|
||||
bool ggml_cuda_should_use_topk_moe(const ggml_tensor * softmax, const ggml_tensor * weights);
|
||||
bool ggml_cuda_should_use_topk_moe(const ggml_tensor * softmax, const ggml_tensor * weights, const ggml_tensor * clamp = nullptr);
|
||||
|
||||
std::initializer_list<enum ggml_op> ggml_cuda_topk_moe_ops(bool with_norm, bool delayed_softmax = false);
|
||||
|
||||
@@ -32,6 +32,7 @@
|
||||
#include "pad.hpp"
|
||||
#include "quantize.hpp"
|
||||
#include "quants.hpp"
|
||||
#include "roll.hpp"
|
||||
#include "rope.hpp"
|
||||
#include "set_rows.hpp"
|
||||
#include "softmax.hpp"
|
||||
|
||||
@@ -48,6 +48,7 @@
|
||||
#include "ggml-sycl/set.hpp"
|
||||
#include "ggml-sycl/sycl_hw.hpp"
|
||||
#include "ggml-sycl/getrows.hpp"
|
||||
#include "ggml-sycl/repeat_back.hpp"
|
||||
#include "ggml-sycl/quantize.hpp"
|
||||
#include "ggml.h"
|
||||
|
||||
@@ -2615,6 +2616,10 @@ catch (sycl::exception const &exc) {
|
||||
std::exit(1);
|
||||
}
|
||||
|
||||
static void ggml_sycl_repeat_back(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
|
||||
ggml_sycl_op_repeat_back(ctx, dst);
|
||||
}
|
||||
|
||||
static void ggml_sycl_get_rows(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/2);
|
||||
@@ -3679,6 +3684,9 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg
|
||||
case GGML_OP_REPEAT:
|
||||
ggml_sycl_repeat(ctx, dst);
|
||||
break;
|
||||
case GGML_OP_REPEAT_BACK:
|
||||
ggml_sycl_repeat_back(ctx, dst);
|
||||
break;
|
||||
case GGML_OP_GET_ROWS:
|
||||
ggml_sycl_get_rows(ctx, dst);
|
||||
break;
|
||||
@@ -3913,6 +3921,9 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg
|
||||
case GGML_OP_GATED_LINEAR_ATTN:
|
||||
ggml_sycl_op_gated_linear_attn(ctx, dst);
|
||||
break;
|
||||
case GGML_OP_ROLL:
|
||||
ggml_sycl_roll(ctx, dst);
|
||||
break;
|
||||
case GGML_OP_ARANGE:
|
||||
ggml_sycl_arange(ctx, dst);
|
||||
break;
|
||||
@@ -4516,6 +4527,11 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
|
||||
ggml_type src0_type = op->src[0]->type;
|
||||
return src0_type != GGML_TYPE_I32 && src0_type != GGML_TYPE_I16;
|
||||
}
|
||||
case GGML_OP_REPEAT_BACK:
|
||||
{
|
||||
ggml_type src0_type = op->src[0]->type;
|
||||
return src0_type == GGML_TYPE_F32;
|
||||
}
|
||||
case GGML_OP_DUP:
|
||||
case GGML_OP_ARGMAX:
|
||||
case GGML_OP_NONE:
|
||||
@@ -4586,6 +4602,8 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
|
||||
case GGML_OP_RWKV_WKV7:
|
||||
case GGML_OP_GATED_LINEAR_ATTN:
|
||||
return true;
|
||||
case GGML_OP_ROLL:
|
||||
return op->type == GGML_TYPE_F32;
|
||||
case GGML_OP_ARANGE:
|
||||
return op->type == GGML_TYPE_F32;
|
||||
default:
|
||||
|
||||
56
ggml/src/ggml-sycl/repeat_back.cpp
Normal file
56
ggml/src/ggml-sycl/repeat_back.cpp
Normal file
@@ -0,0 +1,56 @@
|
||||
#include "repeat_back.hpp"
|
||||
|
||||
#include "common.hpp"
|
||||
|
||||
void ggml_sycl_op_repeat_back(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
|
||||
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(dst->type == GGML_TYPE_F32);
|
||||
|
||||
const float * src0_dd = (const float *) dst->src[0]->data;
|
||||
float * dst_dd = (float *) dst->data;
|
||||
|
||||
const int64_t ne0 = dst->ne[0], ne1 = dst->ne[1], ne2 = dst->ne[2], ne3 = dst->ne[3];
|
||||
const int64_t ne00 = dst->src[0]->ne[0], ne01 = dst->src[0]->ne[1], ne02 = dst->src[0]->ne[2],
|
||||
ne03 = dst->src[0]->ne[3];
|
||||
|
||||
const int nr0 = (int) (ne00 / ne0);
|
||||
const int nr1 = (int) (ne01 / ne1);
|
||||
const int nr2 = (int) (ne02 / ne2);
|
||||
const int nr3 = (int) (ne03 / ne3);
|
||||
|
||||
const size_t total = ne0 * ne1 * ne2 * ne3;
|
||||
const int BLOCK_SIZE = 256;
|
||||
const int num_blocks = (total + BLOCK_SIZE - 1) / BLOCK_SIZE;
|
||||
|
||||
queue_ptr stream = ctx.stream();
|
||||
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<1>(sycl::range<1>(num_blocks * BLOCK_SIZE), sycl::range<1>(BLOCK_SIZE)),
|
||||
[=](sycl::nd_item<1> item_ct1) {
|
||||
const size_t i = item_ct1.get_global_linear_id();
|
||||
if (i >= total) {
|
||||
return;
|
||||
}
|
||||
|
||||
const int i0 = i % ne0;
|
||||
const int i1 = (i / ne0) % ne1;
|
||||
const int i2 = (i / (ne0 * ne1)) % ne2;
|
||||
const int i3 = i / (ne0 * ne1 * ne2);
|
||||
|
||||
float acc = 0.0f;
|
||||
|
||||
for (int j3 = 0; j3 < nr3; ++j3) {
|
||||
for (int j2 = 0; j2 < nr2; ++j2) {
|
||||
for (int j1 = 0; j1 < nr1; ++j1) {
|
||||
for (int j0 = 0; j0 < nr0; ++j0) {
|
||||
acc += src0_dd[(i0 + j0 * ne0) + (i1 + j1 * ne1) * ne00 + (i2 + j2 * ne2) * ne00 * ne01 +
|
||||
(i3 + j3 * ne3) * ne00 * ne01 * ne02];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
dst_dd[i] = acc;
|
||||
});
|
||||
}
|
||||
8
ggml/src/ggml-sycl/repeat_back.hpp
Normal file
8
ggml/src/ggml-sycl/repeat_back.hpp
Normal file
@@ -0,0 +1,8 @@
|
||||
#ifndef GGML_SYCL_REPEAT_BACK_HPP
|
||||
#define GGML_SYCL_REPEAT_BACK_HPP
|
||||
|
||||
#include "common.hpp"
|
||||
|
||||
void ggml_sycl_op_repeat_back(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
|
||||
|
||||
#endif // GGML_SYCL_REPEAT_BACK_HPP
|
||||
122
ggml/src/ggml-sycl/roll.cpp
Normal file
122
ggml/src/ggml-sycl/roll.cpp
Normal file
@@ -0,0 +1,122 @@
|
||||
#include "roll.hpp"
|
||||
#include "common.hpp"
|
||||
|
||||
using namespace sycl;
|
||||
|
||||
static inline int wrap_add(int i, int shift, int n) {
|
||||
|
||||
int s = i + shift;
|
||||
return (s >= n) ? (s - n) : s;
|
||||
}
|
||||
|
||||
static void kernel_roll_fused_i0_i1(
|
||||
queue &q,
|
||||
const float *src_d,
|
||||
float *dst_d,
|
||||
int ne0, int ne1, int ne2, int ne3,
|
||||
int sh0, int sh1, int sh2, int sh3)
|
||||
{
|
||||
if (ne0 == 0 || ne1 == 0 || ne2 == 0 || ne3 == 0) return;
|
||||
|
||||
|
||||
const int stride1 = ne0;
|
||||
const int stride2 = ne0 * ne1;
|
||||
const int stride3 = ne0 * ne1 * ne2;
|
||||
|
||||
|
||||
const int shNe0 = (ne0 - sh0) % ne0;
|
||||
const int shNe1 = (ne1 - sh1) % ne1;
|
||||
const int shNe2 = (ne2 - sh2) % ne2;
|
||||
const int shNe3 = (ne3 - sh3) % ne3;
|
||||
|
||||
|
||||
const size_t g0 = (size_t) ne3;
|
||||
const size_t g1 = (size_t) ne2;
|
||||
const size_t g2 = (size_t) (ne1 * ne0);
|
||||
|
||||
const range<3> global{ g0, g1, g2 };
|
||||
|
||||
q.submit([&](handler &h) {
|
||||
h.parallel_for(global, [=](id<3> idx) {
|
||||
const int i3 = (int) idx[0];
|
||||
const int i2 = (int) idx[1];
|
||||
|
||||
const int fused = (int) idx[2];
|
||||
const int i1 = fused / ne0;
|
||||
const int i0 = fused - i1 * ne0; // fused % ne0
|
||||
|
||||
|
||||
const int idx_dst = i0
|
||||
+ i1 * stride1
|
||||
+ i2 * stride2
|
||||
+ i3 * stride3;
|
||||
|
||||
|
||||
const int s0 = wrap_add(i0, shNe0, ne0);
|
||||
const int s1 = wrap_add(i1, shNe1, ne1);
|
||||
const int s2 = wrap_add(i2, shNe2, ne2);
|
||||
const int s3 = wrap_add(i3, shNe3, ne3);
|
||||
|
||||
const int idx_src = s0
|
||||
+ s1 * stride1
|
||||
+ s2 * stride2
|
||||
+ s3 * stride3;
|
||||
|
||||
dst_d[idx_dst] = src_d[idx_src];
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
void ggml_sycl_roll(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
|
||||
GGML_ASSERT(dst->type == GGML_TYPE_F32);
|
||||
|
||||
const ggml_tensor *src = dst->src[0];
|
||||
GGML_ASSERT(src && src->type == GGML_TYPE_F32);
|
||||
|
||||
const int ne0 = (int) dst->ne[0];
|
||||
const int ne1 = (int) dst->ne[1];
|
||||
const int ne2 = (int) dst->ne[2];
|
||||
const int ne3 = (int) dst->ne[3];
|
||||
|
||||
const int32_t *params = (const int32_t *) dst->op_params;
|
||||
int shift0 = params[0];
|
||||
int shift1 = params[1];
|
||||
int shift2 = params[2];
|
||||
int shift3 = params[3];
|
||||
|
||||
|
||||
if ((shift0 | shift1 | shift2 | shift3) == 0) {
|
||||
const size_t nb = ggml_nbytes(src);
|
||||
queue *q = ctx.stream();
|
||||
SYCL_CHECK(CHECK_TRY_ERROR(q->memcpy(dst->data, src->data, nb)));
|
||||
return;
|
||||
}
|
||||
|
||||
auto norm = [](int sh, int n) -> int {
|
||||
if (n <= 0) return 0;
|
||||
sh %= n;
|
||||
if (sh < 0) sh += n;
|
||||
return sh;
|
||||
};
|
||||
shift0 = norm(shift0, ne0);
|
||||
shift1 = norm(shift1, ne1);
|
||||
shift2 = norm(shift2, ne2);
|
||||
shift3 = norm(shift3, ne3);
|
||||
|
||||
try {
|
||||
queue *q = ctx.stream();
|
||||
|
||||
const float *src_d = (const float *) src->data;
|
||||
float *dst_d = (float *) dst->data;
|
||||
GGML_ASSERT(src_d && dst_d);
|
||||
|
||||
kernel_roll_fused_i0_i1(
|
||||
*q, src_d, dst_d,
|
||||
ne0, ne1, ne2, ne3,
|
||||
shift0, shift1, shift2, shift3
|
||||
);
|
||||
} catch (const std::exception &e) {
|
||||
std::fprintf(stderr, "[SYCL-ROLL] ERROR: %s\n", e.what());
|
||||
throw;
|
||||
}
|
||||
}
|
||||
20
ggml/src/ggml-sycl/roll.hpp
Normal file
20
ggml/src/ggml-sycl/roll.hpp
Normal file
@@ -0,0 +1,20 @@
|
||||
//
|
||||
// MIT license
|
||||
// Copyright (C) 2024 Intel Corporation
|
||||
// SPDX-License-Identifier: MIT
|
||||
//
|
||||
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
|
||||
#ifndef GGML_SYCL_ROLL_HPP
|
||||
#define GGML_SYCL_ROLL_HPP
|
||||
|
||||
#include "common.hpp"
|
||||
|
||||
void ggml_sycl_roll(ggml_backend_sycl_context & ctx, ggml_tensor *dst);
|
||||
|
||||
#endif // GGML_SYCL_ROLL_HPP
|
||||
@@ -6369,6 +6369,8 @@ void llama_model::print_info() const {
|
||||
LLAMA_LOG_INFO("%s: n_ff = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_ff(il); }, hparams.n_layer).c_str());
|
||||
LLAMA_LOG_INFO("%s: n_expert = %u\n", __func__, hparams.n_expert);
|
||||
LLAMA_LOG_INFO("%s: n_expert_used = %u\n", __func__, hparams.n_expert_used);
|
||||
LLAMA_LOG_INFO("%s: n_expert_groups = %d\n", __func__, hparams.n_expert_groups);
|
||||
LLAMA_LOG_INFO("%s: n_group_used = %d\n", __func__, hparams.n_group_used);
|
||||
LLAMA_LOG_INFO("%s: causal attn = %d\n", __func__, hparams.causal_attn);
|
||||
LLAMA_LOG_INFO("%s: pooling type = %d\n", __func__, hparams.pooling_type);
|
||||
LLAMA_LOG_INFO("%s: rope type = %d\n", __func__, hparams.rope_type);
|
||||
@@ -6469,8 +6471,6 @@ void llama_model::print_info() const {
|
||||
LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp);
|
||||
LLAMA_LOG_INFO("%s: n_ff_shexp = %d\n", __func__, hparams.n_ff_shexp);
|
||||
LLAMA_LOG_INFO("%s: n_expert_shared = %d\n", __func__, hparams.n_expert_shared);
|
||||
LLAMA_LOG_INFO("%s: n_expert_groups = %d\n", __func__, hparams.n_expert_groups);
|
||||
LLAMA_LOG_INFO("%s: n_group_used = %d\n", __func__, hparams.n_group_used);
|
||||
LLAMA_LOG_INFO("%s: expert_weights_scale = %.1f\n", __func__, hparams.expert_weights_scale);
|
||||
LLAMA_LOG_INFO("%s: expert_weights_norm = %d\n", __func__, hparams.expert_weights_norm);
|
||||
LLAMA_LOG_INFO("%s: expert_gating_func = %s\n", __func__, llama_expert_gating_func_name((llama_expert_gating_func_type) hparams.expert_gating_func));
|
||||
|
||||
@@ -511,7 +511,7 @@ struct test_result {
|
||||
};
|
||||
|
||||
// Printer classes for different output formats
|
||||
enum class test_status_t { NOT_SUPPORTED, OK, FAIL };
|
||||
enum class test_status_t { NOT_SUPPORTED, OK, FAIL, SKIPPED };
|
||||
|
||||
struct test_operation_info {
|
||||
std::string op_name;
|
||||
@@ -687,6 +687,8 @@ struct printer {
|
||||
virtual void print_backend_status(const backend_status_info & info) { (void) info; }
|
||||
|
||||
virtual void print_overall_summary(const overall_summary_info & info) { (void) info; }
|
||||
|
||||
virtual void print_failed_tests(const std::vector<std::string> & failed_tests) { (void) failed_tests; }
|
||||
};
|
||||
|
||||
struct console_printer : public printer {
|
||||
@@ -804,6 +806,17 @@ struct console_printer : public printer {
|
||||
}
|
||||
}
|
||||
|
||||
void print_failed_tests(const std::vector<std::string> & failed_tests) override {
|
||||
if (failed_tests.empty()) {
|
||||
return;
|
||||
}
|
||||
|
||||
printf("\nFailing tests:\n");
|
||||
for (const auto & test_name : failed_tests) {
|
||||
printf(" %s\n", test_name.c_str());
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
void print_test_console(const test_result & result) {
|
||||
printf(" %s(%s): ", result.op_name.c_str(), result.op_params.c_str());
|
||||
@@ -1056,6 +1069,8 @@ struct test_case {
|
||||
|
||||
std::vector<ggml_tensor *> sentinels;
|
||||
|
||||
std::string current_op_name;
|
||||
|
||||
void add_sentinel(ggml_context * ctx) {
|
||||
if (mode == MODE_PERF || mode == MODE_GRAD || mode == MODE_SUPPORT) {
|
||||
return;
|
||||
@@ -1127,7 +1142,10 @@ struct test_case {
|
||||
}
|
||||
}
|
||||
|
||||
bool eval(ggml_backend_t backend1, ggml_backend_t backend2, const char * op_names_filter, printer * output_printer) {
|
||||
test_status_t eval(ggml_backend_t backend1,
|
||||
ggml_backend_t backend2,
|
||||
const char * op_names_filter,
|
||||
printer * output_printer) {
|
||||
mode = MODE_TEST;
|
||||
|
||||
ggml_init_params params = {
|
||||
@@ -1144,11 +1162,12 @@ struct test_case {
|
||||
add_sentinel(ctx);
|
||||
|
||||
ggml_tensor * out = build_graph(ctx);
|
||||
std::string current_op_name = op_desc(out);
|
||||
current_op_name = op_desc(out);
|
||||
|
||||
if (!matches_filter(out, op_names_filter)) {
|
||||
//printf(" %s: skipping\n", op_desc(out).c_str());
|
||||
ggml_free(ctx);
|
||||
return true;
|
||||
return test_status_t::SKIPPED;
|
||||
}
|
||||
|
||||
// check if the backends support the ops
|
||||
@@ -1172,7 +1191,7 @@ struct test_case {
|
||||
}
|
||||
|
||||
ggml_free(ctx);
|
||||
return true;
|
||||
return test_status_t::NOT_SUPPORTED;
|
||||
}
|
||||
|
||||
// post-graph sentinel
|
||||
@@ -1184,7 +1203,7 @@ struct test_case {
|
||||
if (buf == NULL) {
|
||||
printf("failed to allocate tensors [%s] ", ggml_backend_name(backend1));
|
||||
ggml_free(ctx);
|
||||
return false;
|
||||
return test_status_t::FAIL;
|
||||
}
|
||||
|
||||
// build graph
|
||||
@@ -1289,7 +1308,7 @@ struct test_case {
|
||||
output_printer->print_test_result(result);
|
||||
}
|
||||
|
||||
return test_passed;
|
||||
return test_passed ? test_status_t::OK : test_status_t::FAIL;
|
||||
}
|
||||
|
||||
bool eval_perf(ggml_backend_t backend, const char * op_names_filter, printer * output_printer) {
|
||||
@@ -1306,7 +1325,7 @@ struct test_case {
|
||||
GGML_ASSERT(ctx);
|
||||
|
||||
ggml_tensor * out = build_graph(ctx.get());
|
||||
std::string current_op_name = op_desc(out);
|
||||
current_op_name = op_desc(out);
|
||||
if (!matches_filter(out, op_names_filter)) {
|
||||
//printf(" %s: skipping\n", op_desc(out).c_str());
|
||||
return true;
|
||||
@@ -1435,8 +1454,9 @@ struct test_case {
|
||||
ggml_context_ptr ctx(ggml_init(params)); // smart ptr
|
||||
GGML_ASSERT(ctx);
|
||||
|
||||
ggml_tensor * out = build_graph(ctx.get());
|
||||
std::string current_op_name = op_desc(out);
|
||||
ggml_tensor * out = build_graph(ctx.get());
|
||||
current_op_name = op_desc(out);
|
||||
|
||||
if (!matches_filter(out, op_names_filter)) {
|
||||
return true;
|
||||
}
|
||||
@@ -4712,6 +4732,7 @@ struct test_topk_moe: public test_case {
|
||||
out = ggml_reshape_2d(ctx, out, n_expert_used, n_tokens);
|
||||
ggml_tensor * weights_sum = ggml_sum_rows(ctx, out); // [1, n_tokens]
|
||||
|
||||
weights_sum = ggml_clamp(ctx, weights_sum, 6.103515625e-5, INFINITY);
|
||||
out = ggml_div(ctx, out, weights_sum); // [n_expert_used, n_tokens]
|
||||
out = ggml_reshape_3d(ctx, out, 1, n_expert_used, n_tokens);
|
||||
}
|
||||
@@ -6697,6 +6718,9 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
|
||||
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 1024, {3, 2}, {1, 1}));
|
||||
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 8, 1024, {3, 2}, {1, 1}));
|
||||
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 1024, {3, 2}, {1, 1}));
|
||||
|
||||
// test cases with large batch size
|
||||
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 8, 256, {1536, 1}, {1, 1}));
|
||||
}
|
||||
}
|
||||
for (ggml_type type_a : other_types) {
|
||||
@@ -7356,16 +7380,26 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
|
||||
}
|
||||
|
||||
size_t n_ok = 0;
|
||||
size_t tests_run = 0;
|
||||
std::vector<std::string> failed_tests;
|
||||
for (auto & test : test_cases) {
|
||||
if (test->eval(backend, backend_cpu, op_names_filter, output_printer)) {
|
||||
test_status_t status = test->eval(backend, backend_cpu, op_names_filter, output_printer);
|
||||
if (status == test_status_t::SKIPPED || status == test_status_t::NOT_SUPPORTED) {
|
||||
continue;
|
||||
}
|
||||
tests_run++;
|
||||
if (status == test_status_t::OK) {
|
||||
n_ok++;
|
||||
} else if (status == test_status_t::FAIL) {
|
||||
failed_tests.push_back(test->current_op_name + "(" + test->vars() + ")");
|
||||
}
|
||||
}
|
||||
output_printer->print_summary(test_summary_info(n_ok, test_cases.size(), false));
|
||||
output_printer->print_summary(test_summary_info(n_ok, tests_run, false));
|
||||
output_printer->print_failed_tests(failed_tests);
|
||||
|
||||
ggml_backend_free(backend_cpu);
|
||||
|
||||
return n_ok == test_cases.size();
|
||||
return n_ok == tests_run;
|
||||
}
|
||||
|
||||
if (mode == MODE_GRAD) {
|
||||
|
||||
Reference in New Issue
Block a user