Compare commits

...

3 Commits

Author SHA1 Message Date
Shouzheng Liu
fc8ef549e5 metal : enable ggml-alloc (#2627)
* metal: enable ggml-alloc

Make ggml-alloc work with concurrently dispatch.

* style-fix

Co-authored-by: slaren <slarengh@gmail.com>

---------

Co-authored-by: slaren <slarengh@gmail.com>
Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
2023-08-16 23:08:28 +03:00
Shouzheng Liu
bf83bff674 metal : matrix-matrix multiplication kernel (#2615)
* metal: matrix-matrix multiplication kernel

This commit removes MPS and uses custom matrix-matrix multiplication
kernels for all quantization types. This commit also adds grouped-query
attention to support llama2 70B.

* metal: fix performance degradation from gqa

Integers are slow on the GPU, and 64-bit divides are extremely slow.
In the context of GQA, we introduce a 64-bit divide that cannot be
optimized out by the compiler, which results in a decrease of ~8% in
inference performance. This commit fixes that issue by calculating a
part of the offset with a 32-bit divide. Naturally, this limits the
size of a single matrix to ~4GB. However, this limitation should
suffice for the near future.

* metal: fix bugs for GQA and perplexity test.

I mixed up ne02 and nb02 in previous commit.
2023-08-16 23:07:04 +03:00
Georgi Gerganov
b5ffb2849d scripts : add helper script to get wikitext 2023-08-15 10:05:25 +03:00
10 changed files with 592 additions and 662 deletions

View File

@@ -296,7 +296,6 @@ if (LLAMA_METAL)
find_library(FOUNDATION_LIBRARY Foundation REQUIRED)
find_library(METAL_FRAMEWORK Metal REQUIRED)
find_library(METALKIT_FRAMEWORK MetalKit REQUIRED)
find_library(METALPERFORMANCE_FRAMEWORK MetalPerformanceShaders REQUIRED)
set(GGML_SOURCES_METAL ggml-metal.m ggml-metal.h)
@@ -313,7 +312,6 @@ if (LLAMA_METAL)
${FOUNDATION_LIBRARY}
${METAL_FRAMEWORK}
${METALKIT_FRAMEWORK}
${METALPERFORMANCE_FRAMEWORK}
)
endif()

View File

@@ -283,7 +283,7 @@ endif # LLAMA_CLBLAST
ifdef LLAMA_METAL
CFLAGS += -DGGML_USE_METAL -DGGML_METAL_NDEBUG
CXXFLAGS += -DGGML_USE_METAL
LDFLAGS += -framework Foundation -framework Metal -framework MetalKit -framework MetalPerformanceShaders
LDFLAGS += -framework Foundation -framework Metal -framework MetalKit
OBJS += ggml-metal.o
endif # LLAMA_METAL

View File

@@ -14,8 +14,6 @@
with pkgs.darwin.apple_sdk_11_0.frameworks; [
Accelerate
MetalKit
MetalPerformanceShaders
MetalPerformanceShadersGraph
]
else if isAarch32 && isDarwin then
with pkgs.darwin.apple_sdk.frameworks; [

View File

@@ -67,6 +67,8 @@ struct ggml_allocr {
struct hash_node hash_table[GGML_GRAPH_HASHTABLE_SIZE];
size_t max_size;
bool measure;
int parse_seq[GGML_MAX_NODES];
bool has_parse_seq;
#ifdef GGML_ALLOCATOR_DEBUG
struct ggml_tensor * allocated_tensors[1024];
@@ -229,6 +231,17 @@ static void ggml_allocator_free_tensor(struct ggml_allocr * alloc, struct ggml_t
alloc->n_free_blocks++;
}
void ggml_allocr_set_parse_seq(struct ggml_allocr * alloc, int * list, int n) {
int pos = 0;
for (int i = 0; i < n; i++) {
if (list[i] != -1) {
alloc->parse_seq[pos] = list[i];
pos++;
}
}
alloc->has_parse_seq = true;
}
void ggml_allocr_reset(struct ggml_allocr * alloc) {
alloc->n_free_blocks = 1;
size_t align_offset = aligned_offset(alloc->data, 0, alloc->alignment);
@@ -248,6 +261,8 @@ struct ggml_allocr * ggml_allocr_new(void * data, size_t size, size_t alignment)
/*.hash_table = */ {{0}},
/*.max_size = */ 0,
/*.measure = */ false,
/*.parse_seq = */ {0},
/*.has_parse_seq = */ false,
#ifdef GGML_ALLOCATOR_DEBUG
/*.allocated_tensors = */ = {0},
#endif
@@ -275,6 +290,8 @@ struct ggml_allocr * ggml_allocr_new_measure(size_t alignment) {
/*.hash_table = */ {{0}},
/*.max_size = */ 0,
/*.measure = */ true,
/*.parse_seq = */ {0},
/*.has_parse_seq = */ false,
#ifdef GGML_ALLOCATOR_DEBUG
/*.allocated_tensors = */ = {0},
#endif
@@ -473,7 +490,13 @@ static size_t ggml_allocator_alloc_graph_tensors_n(
allocate_node(alloc, input);
}
}
for (int i = 0; i < gf->n_nodes; i++) {
for (int ind = 0; ind < gf->n_nodes; ind++) {
int i;
if (alloc->has_parse_seq) {
i = alloc->parse_seq[ind];
} else {
i = ind;
}
struct ggml_tensor * node = gf->nodes[i];
// allocate parents (leafs)

View File

@@ -10,6 +10,10 @@ extern "C" {
GGML_API struct ggml_allocr * ggml_allocr_new(void * data, size_t size, size_t alignment);
GGML_API struct ggml_allocr * ggml_allocr_new_measure(size_t alignment);
// tell the allocator to parse nodes following the order described in the list
// you should call this if your graph are optimized to execute out-of-order
GGML_API void ggml_allocr_set_parse_seq(struct ggml_allocr * alloc, int * list, int n);
GGML_API void ggml_allocr_free(struct ggml_allocr * alloc);
GGML_API bool ggml_allocr_is_measure(struct ggml_allocr * alloc);
GGML_API void ggml_allocr_reset(struct ggml_allocr * alloc);

View File

@@ -63,10 +63,13 @@ void ggml_metal_get_tensor(struct ggml_metal_context * ctx, struct ggml_tensor *
// try to find operations that can be run concurrently in the graph
// you should run it again if the topology of your graph changes
void ggml_metal_graph_find_concurrency(struct ggml_metal_context * ctx, struct ggml_cgraph * gf);
void ggml_metal_graph_find_concurrency(struct ggml_metal_context * ctx, struct ggml_cgraph * gf, bool check_mem);
// if the graph has been optimized for concurrently dispatch
bool ggml_metal_if_optimized(struct ggml_metal_context * ctx);
// if the graph has been optimized for concurrently dispatch, return length of the concur_list if optimized
int ggml_metal_if_optimized(struct ggml_metal_context * ctx);
// output the concur_list for ggml_alloc
int * ggml_metal_get_concur_list(struct ggml_metal_context * ctx);
// same as ggml_graph_compute but uses Metal
// creates gf->n_threads command buffers in parallel

View File

@@ -5,7 +5,6 @@
#import <Foundation/Foundation.h>
#import <Metal/Metal.h>
#import <MetalPerformanceShaders/MetalPerformanceShaders.h>
#undef MIN
#undef MAX
@@ -79,6 +78,14 @@ struct ggml_metal_context {
GGML_METAL_DECL_KERNEL(mul_mat_q4_K_f32);
GGML_METAL_DECL_KERNEL(mul_mat_q5_K_f32);
GGML_METAL_DECL_KERNEL(mul_mat_q6_K_f32);
GGML_METAL_DECL_KERNEL(mul_mm_f16_f32);
GGML_METAL_DECL_KERNEL(mul_mm_q4_0_f32);
GGML_METAL_DECL_KERNEL(mul_mm_q4_1_f32);
GGML_METAL_DECL_KERNEL(mul_mm_q2_K_f32);
GGML_METAL_DECL_KERNEL(mul_mm_q3_K_f32);
GGML_METAL_DECL_KERNEL(mul_mm_q4_K_f32);
GGML_METAL_DECL_KERNEL(mul_mm_q5_K_f32);
GGML_METAL_DECL_KERNEL(mul_mm_q6_K_f32);
GGML_METAL_DECL_KERNEL(rope);
GGML_METAL_DECL_KERNEL(alibi_f32);
GGML_METAL_DECL_KERNEL(cpy_f32_f16);
@@ -110,13 +117,6 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
ctx->n_buffers = 0;
ctx->concur_list_len = 0;
// determine if we can use MPS
if (MPSSupportsMTLDevice(ctx->device)) {
fprintf(stderr, "%s: using MPS\n", __func__);
} else {
fprintf(stderr, "%s: not using MPS\n", __func__);
GGML_ASSERT(false && "MPS not supported");
}
#if 0
// compile from source string and show compile log
@@ -196,6 +196,14 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
GGML_METAL_ADD_KERNEL(mul_mat_q4_K_f32);
GGML_METAL_ADD_KERNEL(mul_mat_q5_K_f32);
GGML_METAL_ADD_KERNEL(mul_mat_q6_K_f32);
GGML_METAL_ADD_KERNEL(mul_mm_f16_f32);
GGML_METAL_ADD_KERNEL(mul_mm_q4_0_f32);
GGML_METAL_ADD_KERNEL(mul_mm_q4_1_f32);
GGML_METAL_ADD_KERNEL(mul_mm_q2_K_f32);
GGML_METAL_ADD_KERNEL(mul_mm_q3_K_f32);
GGML_METAL_ADD_KERNEL(mul_mm_q4_K_f32);
GGML_METAL_ADD_KERNEL(mul_mm_q5_K_f32);
GGML_METAL_ADD_KERNEL(mul_mm_q6_K_f32);
GGML_METAL_ADD_KERNEL(rope);
GGML_METAL_ADD_KERNEL(alibi_f32);
GGML_METAL_ADD_KERNEL(cpy_f32_f16);
@@ -228,11 +236,12 @@ void ggml_metal_set_n_cb(struct ggml_metal_context * ctx, int n_cb) {
ctx->n_cb = n_cb;
}
bool ggml_metal_if_optimized(struct ggml_metal_context * ctx) {
if (ctx->concur_list_len) {
return true;
}
return false;
int ggml_metal_if_optimized(struct ggml_metal_context * ctx) {
return ctx->concur_list_len;
}
int * ggml_metal_get_concur_list(struct ggml_metal_context * ctx) {
return ctx->concur_list;
}
// finds the Metal buffer that contains the tensor data on the GPU device
@@ -375,7 +384,7 @@ void ggml_metal_get_tensor(
void ggml_metal_graph_find_concurrency(
struct ggml_metal_context * ctx,
struct ggml_cgraph * gf) {
struct ggml_cgraph * gf, bool check_mem) {
int search_depth = gf->n_nodes; //we only find concurrency in this range to avoid wasting too much time
int nodes_unused[GGML_MAX_CONCUR];
@@ -422,7 +431,7 @@ void ggml_metal_graph_find_concurrency(
}
}
}
if (exe_flag) {
if (exe_flag && check_mem) {
// check if nodes[i]'s data will be overwritten by a node before nodes[i].
// if node[5] and node[3] write to the same memory region, then we can't issue node[5] before node[3]
int64_t data_start = (int64_t) gf->nodes[i]->data;
@@ -506,7 +515,7 @@ void ggml_metal_graph_compute(
id<MTLCommandBuffer> command_buffer = command_buffers[cb_idx];
id<MTLComputeCommandEncoder> encoder = nil;
id<MTLComputeCommandEncoder> encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
const int node_start = (cb_idx + 0) * n_nodes_per_cb;
const int node_end = (cb_idx == n_cb - 1) ? n_nodes : (cb_idx + 1) * n_nodes_per_cb;
@@ -515,10 +524,6 @@ void ggml_metal_graph_compute(
const int i = has_concur ? ctx->concur_list[ind] : ind;
if (i == -1) {
if (encoder == nil) {
encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
continue;
}
[encoder memoryBarrierWithScope:MTLBarrierScopeBuffers];
continue;
}
@@ -592,10 +597,6 @@ void ggml_metal_graph_compute(
} break;
case GGML_OP_ADD:
{
if (encoder == nil) {
encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
}
if (ggml_nelements(src1) == ne10) {
// src1 is a row
[encoder setComputePipelineState:ctx->pipeline_add_row];
@@ -613,10 +614,6 @@ void ggml_metal_graph_compute(
} break;
case GGML_OP_MUL:
{
if (encoder == nil) {
encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
}
if (ggml_nelements(src1) == ne10) {
// src1 is a row
[encoder setComputePipelineState:ctx->pipeline_mul_row];
@@ -634,10 +631,6 @@ void ggml_metal_graph_compute(
} break;
case GGML_OP_SCALE:
{
if (encoder == nil) {
encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
}
const float scale = *(const float *) src1->data;
[encoder setComputePipelineState:ctx->pipeline_scale];
@@ -653,10 +646,6 @@ void ggml_metal_graph_compute(
switch (ggml_get_unary_op(gf->nodes[i])) {
case GGML_UNARY_OP_SILU:
{
if (encoder == nil) {
encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
}
[encoder setComputePipelineState:ctx->pipeline_silu];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
@@ -667,10 +656,6 @@ void ggml_metal_graph_compute(
} break;
case GGML_UNARY_OP_RELU:
{
if (encoder == nil) {
encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
}
[encoder setComputePipelineState:ctx->pipeline_relu];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
@@ -681,10 +666,6 @@ void ggml_metal_graph_compute(
} break;
case GGML_UNARY_OP_GELU:
{
if (encoder == nil) {
encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
}
[encoder setComputePipelineState:ctx->pipeline_gelu];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
@@ -701,10 +682,6 @@ void ggml_metal_graph_compute(
} break;
case GGML_OP_SOFT_MAX:
{
if (encoder == nil) {
encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
}
const int nth = 32;
[encoder setComputePipelineState:ctx->pipeline_soft_max];
@@ -719,10 +696,6 @@ void ggml_metal_graph_compute(
} break;
case GGML_OP_DIAG_MASK_INF:
{
if (encoder == nil) {
encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
}
const int n_past = ((int32_t *)(dst->op_params))[0];
[encoder setComputePipelineState:ctx->pipeline_diag_mask_inf];
@@ -740,53 +713,43 @@ void ggml_metal_graph_compute(
GGML_ASSERT(ne00 == ne10);
// GGML_ASSERT(ne02 == ne12); // Should be checked on individual data types until broadcast is implemented everywhere
uint gqa = ne12/ne02;
GGML_ASSERT(ne03 == ne13);
// for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
// AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
if (ggml_is_contiguous(src0) &&
ggml_is_contiguous(src1) &&
(src0t == GGML_TYPE_F32 || src0t == GGML_TYPE_F16) && ne11 > 1) {
if (encoder != nil) {
[encoder endEncoding];
encoder = nil;
src1t == GGML_TYPE_F32 &&
[ctx->device supportsFamily:MTLGPUFamilyApple7] &&
ne00%32 == 0 &&
ne11 > 1) {
switch (src0->type) {
case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_mul_mm_f16_f32]; break;
case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_0_f32]; break;
case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_1_f32]; break;
case GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q2_K_f32]; break;
case GGML_TYPE_Q3_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q3_K_f32]; break;
case GGML_TYPE_Q4_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_K_f32]; break;
case GGML_TYPE_Q5_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q5_K_f32]; break;
case GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q6_K_f32]; break;
default: GGML_ASSERT(false && "MUL MAT-MAT not implemented");
}
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:5];
[encoder setBytes:&nb02 length:sizeof(nb02) atIndex:6];
[encoder setBytes:&ne12 length:sizeof(ne12) atIndex:7];
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:8];
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:9];
[encoder setBytes:&gqa length:sizeof(gqa) atIndex:10];
[encoder setThreadgroupMemoryLength:8192 atIndex:0];
[encoder dispatchThreadgroups:MTLSizeMake( (ne11+31)/32, (ne01+63) / 64, ne12) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
}
MPSDataType src0dt = src0t == GGML_TYPE_F32 ? MPSDataTypeFloat32 : MPSDataTypeFloat16;
MPSDataType src1dt = src1t == GGML_TYPE_F32 ? MPSDataTypeFloat32 : MPSDataTypeFloat16;
// for F32 x F32 we use MPS
MPSMatrixDescriptor * desc0 = [MPSMatrixDescriptor
matrixDescriptorWithRows:ne01 columns:ne00 rowBytes:src0->nb[1] dataType:src0dt];
MPSMatrixDescriptor * desc1 = [MPSMatrixDescriptor
matrixDescriptorWithRows:ne11 columns:ne10 rowBytes:src1->nb[1] dataType:src1dt];
MPSMatrixDescriptor * desc = [MPSMatrixDescriptor
matrixDescriptorWithRows:ne1 columns:ne0 rowBytes:dst->nb[1] dataType:MPSDataTypeFloat32];
MPSMatrixMultiplication * mul = [[MPSMatrixMultiplication alloc]
initWithDevice:ctx->device transposeLeft:false transposeRight:true
resultRows:ne11 resultColumns:ne01 interiorColumns:ne00 alpha:1.0 beta:0.0];
// we need to do ne12 multiplications
// TODO: is there a way to do this in parallel - currently very slow ..
// TODO: might be possible to offload part of the computation to ANE using Accelerate's CBLAS
for (int64_t i02 = 0; i02 < ne12; ++i02) {
size_t offs_src0_cur = offs_src0 + i02/(ne12/ne02)*nb02; // gqa not used for now
size_t offs_src1_cur = offs_src1 + i02*nb12;
size_t offs_dst_cur = offs_dst + i02*nb2;
MPSMatrix * mat_src0 = [[MPSMatrix alloc] initWithBuffer:id_src0 offset:offs_src0_cur descriptor:desc0];
MPSMatrix * mat_src1 = [[MPSMatrix alloc] initWithBuffer:id_src1 offset:offs_src1_cur descriptor:desc1];
MPSMatrix * mat_dst = [[MPSMatrix alloc] initWithBuffer:id_dst offset:offs_dst_cur descriptor:desc ];
[mul encodeToCommandBuffer:command_buffer leftMatrix:mat_src1 rightMatrix:mat_src0 resultMatrix:mat_dst];
}
} else {
if (encoder == nil) {
encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
}
else {
int nth0 = 32;
int nth1 = 1;
@@ -885,23 +848,24 @@ void ggml_metal_graph_compute(
[encoder setBytes:&nb12 length:sizeof(nb12) atIndex:14];
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:15];
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:16];
[encoder setBytes:&gqa length:sizeof(gqa) atIndex:17];
if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 ||
src0t == GGML_TYPE_Q2_K || src0t == GGML_TYPE_Q4_K) {
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7) / 8, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7) / 8, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
}
else if (src0t == GGML_TYPE_Q3_K) {
#ifdef GGML_QKK_64
[encoder dispatchThreadgroups:MTLSizeMake((ne01+1)/2, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
[encoder dispatchThreadgroups:MTLSizeMake((ne01+1)/2, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
#else
[encoder dispatchThreadgroups:MTLSizeMake((ne01+3)/4, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
[encoder dispatchThreadgroups:MTLSizeMake((ne01+3)/4, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
#endif
}
else if (src0t == GGML_TYPE_Q5_K) {
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3) / 4, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3) / 4, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
}
else if (src0t == GGML_TYPE_Q6_K) {
[encoder dispatchThreadgroups:MTLSizeMake((ne01+1)/2, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
[encoder dispatchThreadgroups:MTLSizeMake((ne01+1)/2, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
} else {
[encoder setThreadgroupMemoryLength:nth0*sizeof(float) atIndex:0];
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
@@ -910,10 +874,6 @@ void ggml_metal_graph_compute(
} break;
case GGML_OP_GET_ROWS:
{
if (encoder == nil) {
encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
}
switch (src0->type) {
case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_get_rows_f16]; break;
case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_0]; break;
@@ -939,10 +899,6 @@ void ggml_metal_graph_compute(
} break;
case GGML_OP_RMS_NORM:
{
if (encoder == nil) {
encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
}
float eps;
memcpy(&eps, dst->op_params, sizeof(float));
@@ -962,10 +918,6 @@ void ggml_metal_graph_compute(
} break;
case GGML_OP_NORM:
{
if (encoder == nil) {
encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
}
const float eps = 1e-5f;
const int nth = 256;
@@ -984,10 +936,6 @@ void ggml_metal_graph_compute(
} break;
case GGML_OP_ALIBI:
{
if (encoder == nil) {
encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
}
GGML_ASSERT((src0t == GGML_TYPE_F32));
const int n_past = ((int32_t *) dst->op_params)[0]; UNUSED(n_past);
@@ -1027,10 +975,6 @@ void ggml_metal_graph_compute(
} break;
case GGML_OP_ROPE:
{
if (encoder == nil) {
encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
}
const int n_past = ((int32_t *) dst->op_params)[0];
const int n_dims = ((int32_t *) dst->op_params)[1];
const int mode = ((int32_t *) dst->op_params)[2];
@@ -1071,10 +1015,6 @@ void ggml_metal_graph_compute(
case GGML_OP_CPY:
case GGML_OP_CONT:
{
if (encoder == nil) {
encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
}
const int nth = 32;
switch (src0t) {

File diff suppressed because it is too large Load Diff

View File

@@ -63,7 +63,7 @@ static void llama_log_callback_default(llama_log_level level, const char * text,
#define LLAMA_LOG_ERROR(...) llama_log_internal(LLAMA_LOG_LEVEL_ERROR, __VA_ARGS__)
#if !defined(GGML_USE_CUBLAS) && !defined(GGML_USE_METAL)
#if !defined(GGML_USE_CUBLAS)
#include "ggml-alloc.h"
#define LLAMA_USE_ALLOCATOR
#else
@@ -1845,11 +1845,7 @@ static bool llama_eval_internal(
#endif
#ifdef GGML_USE_METAL
if (lctx.ctx_metal && N == 1) {
// TODO: disabled until #2413 is resolved
//if (!ggml_metal_if_optimized(lctx.ctx_metal)) {
// ggml_metal_graph_find_concurrency(lctx.ctx_metal, gf);
//}
if (lctx.ctx_metal) {
ggml_metal_set_n_cb (lctx.ctx_metal, n_threads);
ggml_metal_graph_compute(lctx.ctx_metal, gf);
ggml_metal_get_tensor (lctx.ctx_metal, res);
@@ -1857,22 +1853,6 @@ static bool llama_eval_internal(
ggml_metal_get_tensor(lctx.ctx_metal, embeddings);
}
} else {
// IMPORTANT:
// Since we don't have efficient Matrix x Matrix Metal multiplication yet, we fallback to vanilla
// ggml_graph_compute(). It uses Apple's Accelerate CBLAS API which takes advantage of the ANE or the AMX
// coprocessor.
//
// When we implement Matrix x Matrix Metal multiplication, we can avoid this branch.
// But for now, we have focused only on Matrix x Vector Metal multiplication.
//
// TODO: avoid these syncs via shared memory (ref #1696)
//
if (lctx.ctx_metal) {
// We need to sync the GPU KV cache with the CPU KV cache
ggml_metal_get_tensor(lctx.ctx_metal, kv_self.k);
ggml_metal_get_tensor(lctx.ctx_metal, kv_self.v);
}
ggml_graph_compute_helper(lctx.work_buffer, gf, n_threads);
}
#else
@@ -3303,7 +3283,18 @@ struct llama_context * llama_new_context_with_model(
int n_past = hparams.n_ctx - n_tokens;
llama_token token = llama_token_bos(); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph
ggml_cgraph * gf = llama_build_graph(*ctx, &token, NULL, n_tokens, n_past);
#ifdef GGML_USE_METAL
if (params.n_gpu_layers > 0) {
ctx->ctx_metal = ggml_metal_init(1);
if (!ctx->ctx_metal) {
LLAMA_LOG_ERROR("%s: ggml_metal_init() failed\n", __func__);
llama_free(ctx);
return NULL;
}
ggml_metal_graph_find_concurrency(ctx->ctx_metal, gf, false);
ggml_allocr_set_parse_seq(ctx->alloc, ggml_metal_get_concur_list(ctx->ctx_metal), ggml_metal_if_optimized(ctx->ctx_metal));
}
#endif
// measure memory requirements for the graph
size_t alloc_size = ggml_allocr_alloc_graph(ctx->alloc, gf) + tensor_alignment;
@@ -3321,6 +3312,11 @@ struct llama_context * llama_new_context_with_model(
ctx->buf_alloc.resize(alloc_size);
ctx->alloc = ggml_allocr_new(ctx->buf_alloc.addr, ctx->buf_alloc.size, tensor_alignment);
#ifdef GGML_USE_METAL
if (ctx->ctx_metal) {
ggml_allocr_set_parse_seq(ctx->alloc, ggml_metal_get_concur_list(ctx->ctx_metal), ggml_metal_if_optimized(ctx->ctx_metal));
}
#endif
}
#else
ctx->buf_compute.resize(MEM_REQ_EVAL().at(ctx->model.type) + ggml_graph_overhead());
@@ -3335,13 +3331,6 @@ struct llama_context * llama_new_context_with_model(
#ifdef GGML_USE_METAL
if (params.n_gpu_layers > 0) {
// this allocates all Metal resources and memory buffers
ctx->ctx_metal = ggml_metal_init(1);
if (!ctx->ctx_metal) {
LLAMA_LOG_ERROR("%s: ggml_metal_init() failed\n", __func__);
llama_free(ctx);
return NULL;
}
void * data_ptr = NULL;
size_t data_size = 0;
@@ -3370,8 +3359,7 @@ struct llama_context * llama_new_context_with_model(
LLAMA_METAL_CHECK_BUF(ggml_metal_add_buffer(ctx->ctx_metal, "eval", ctx->buf_compute.addr, ctx->buf_compute.size, 0));
LLAMA_METAL_CHECK_BUF(ggml_metal_add_buffer(ctx->ctx_metal, "kv", ctx->kv_self.buf.addr, ctx->kv_self.buf.size, 0));
LLAMA_METAL_CHECK_BUF(ggml_metal_add_buffer(ctx->ctx_metal, "scr0", ctx->buf_scratch[0].addr, ctx->buf_scratch[0].size, 0));
LLAMA_METAL_CHECK_BUF(ggml_metal_add_buffer(ctx->ctx_metal, "scr1", ctx->buf_scratch[1].addr, ctx->buf_scratch[1].size, 0));
LLAMA_METAL_CHECK_BUF(ggml_metal_add_buffer(ctx->ctx_metal, "alloc", ctx->buf_alloc.addr, ctx->buf_alloc.size, 0));
#undef LLAMA_METAL_CHECK_BUF
}
#endif

View File

@@ -0,0 +1,3 @@
#!/bin/bash
wget https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-raw-v1.zip