mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2026-05-08 18:14:07 +00:00
Compare commits
16 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
f71f40a284 | ||
|
|
d30cb5a7fa | ||
|
|
6c35981a64 | ||
|
|
8b5e19aea6 | ||
|
|
60aea028b5 | ||
|
|
9c55e5c5c2 | ||
|
|
33d7aed4a8 | ||
|
|
6a2bc8bfb7 | ||
|
|
e3a7cf6c5b | ||
|
|
518329b2d4 | ||
|
|
2f5a4e1e09 | ||
|
|
4f41ee11d6 | ||
|
|
3e0be1cace | ||
|
|
6aa892ec2a | ||
|
|
aea9f8b4e7 | ||
|
|
06c1e4abc1 |
@@ -1,4 +1,4 @@
|
||||
ARG ONEAPI_VERSION=2025.0.0-0-devel-ubuntu22.04
|
||||
ARG ONEAPI_VERSION=2025.1.1-0-devel-ubuntu24.04
|
||||
|
||||
## Build Image
|
||||
|
||||
|
||||
2
.github/workflows/build.yml
vendored
2
.github/workflows/build.yml
vendored
@@ -899,7 +899,7 @@ jobs:
|
||||
shell: bash
|
||||
|
||||
env:
|
||||
WINDOWS_BASEKIT_URL: https://registrationcenter-download.intel.com/akdlm/IRC_NAS/b380d914-366b-4b77-a74a-05e3c38b3514/intel-oneapi-base-toolkit-2025.0.0.882_offline.exe
|
||||
WINDOWS_BASEKIT_URL: https://registrationcenter-download.intel.com/akdlm/IRC_NAS/7cd9bba0-7aab-4e30-b3ae-2221006a4a05/intel-oneapi-base-toolkit-2025.1.1.34_offline.exe
|
||||
WINDOWS_DPCPP_MKL: intel.oneapi.win.cpp-dpcpp-common:intel.oneapi.win.mkl.devel:intel.oneapi.win.dnnl:intel.oneapi.win.tbb.devel
|
||||
ONEAPI_ROOT: "C:/Program Files (x86)/Intel/oneAPI"
|
||||
steps:
|
||||
|
||||
2
.github/workflows/release.yml
vendored
2
.github/workflows/release.yml
vendored
@@ -448,7 +448,7 @@ jobs:
|
||||
shell: bash
|
||||
|
||||
env:
|
||||
WINDOWS_BASEKIT_URL: https://registrationcenter-download.intel.com/akdlm/IRC_NAS/b380d914-366b-4b77-a74a-05e3c38b3514/intel-oneapi-base-toolkit-2025.0.0.882_offline.exe
|
||||
WINDOWS_BASEKIT_URL: https://registrationcenter-download.intel.com/akdlm/IRC_NAS/7cd9bba0-7aab-4e30-b3ae-2221006a4a05/intel-oneapi-base-toolkit-2025.1.1.34_offline.exe
|
||||
WINDOWS_DPCPP_MKL: intel.oneapi.win.cpp-dpcpp-common:intel.oneapi.win.mkl.devel:intel.oneapi.win.dnnl:intel.oneapi.win.tbb.devel
|
||||
ONEAPI_ROOT: "C:/Program Files (x86)/Intel/oneAPI"
|
||||
steps:
|
||||
|
||||
@@ -572,4 +572,11 @@ automatically. For example:
|
||||
$ echo "source ~/.llama-completion.bash" >> ~/.bashrc
|
||||
```
|
||||
|
||||
## References
|
||||
## Dependencies
|
||||
|
||||
- [yhirose/cpp-httplib](https://github.com/yhirose/cpp-httplib) - Single-header HTTP server, used by `llama-server` - MIT license
|
||||
- [stb-image](https://github.com/nothings/stb) - Single-header image format decoder, used by multimodal subsystem - Public domain
|
||||
- [nlohmann/json](https://github.com/nlohmann/json) - Single-header JSON library, used by various tools/examples - MIT License
|
||||
- [minja](https://github.com/google/minja) - Minimal Jinja parser in C++, used by various tools/examples - MIT License
|
||||
- [linenoise.cpp](./tools/run/linenoise.cpp/linenoise.cpp) - C++ library that provides readline-like line editing capabilities, used by `llama-run` - BSD 2-Clause License
|
||||
- [curl](https://curl.se/) - Client-side URL transfer library, used by various tools/examples - [CURL License](https://curl.se/docs/copyright.html)
|
||||
|
||||
@@ -121,8 +121,8 @@ if (LLAMA_LLGUIDANCE)
|
||||
|
||||
ExternalProject_Add(llguidance_ext
|
||||
GIT_REPOSITORY https://github.com/guidance-ai/llguidance
|
||||
# v0.7.19 (+ fancy-regex build fix):
|
||||
GIT_TAG b59f98f85269892a7de3d3641ad155366f13daa6
|
||||
# v0.7.20 (+ fix to build on GCC 15):
|
||||
GIT_TAG b5b8b64dba11c4e4ee6b1d1450d3a3ae279891e8
|
||||
PREFIX ${CMAKE_BINARY_DIR}/llguidance
|
||||
SOURCE_DIR ${LLGUIDANCE_SRC}
|
||||
BUILD_IN_SOURCE TRUE
|
||||
|
||||
@@ -2585,7 +2585,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
[](common_params & params, int value) {
|
||||
params.n_junk = value;
|
||||
}
|
||||
).set_examples({LLAMA_EXAMPLE_PASSKEY}));
|
||||
).set_examples({LLAMA_EXAMPLE_PASSKEY, LLAMA_EXAMPLE_PARALLEL}));
|
||||
add_opt(common_arg(
|
||||
{"--pos"}, "N",
|
||||
string_format("position of the passkey in the junk text (default: %d)", params.i_pos),
|
||||
@@ -2648,7 +2648,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
[](common_params & params) {
|
||||
params.is_pp_shared = true;
|
||||
}
|
||||
).set_examples({LLAMA_EXAMPLE_BENCH}));
|
||||
).set_examples({LLAMA_EXAMPLE_BENCH, LLAMA_EXAMPLE_PARALLEL}));
|
||||
add_opt(common_arg(
|
||||
{"-npp"}, "n0,n1,...",
|
||||
"number of prompt tokens",
|
||||
@@ -2880,6 +2880,16 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
params.chat_template = read_file(value);
|
||||
}
|
||||
).set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_CHAT_TEMPLATE_FILE"));
|
||||
add_opt(common_arg(
|
||||
{"--no-prefill-assistant"},
|
||||
string_format(
|
||||
"whether to prefill the assistant's response if the last message is an assistant message (default: prefill enabled)\n"
|
||||
"when this flag is set, if the last message is an assistant message then it will be treated as a full message and not prefilled\n"
|
||||
),
|
||||
[](common_params & params) {
|
||||
params.prefill_assistant = false;
|
||||
}
|
||||
).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_NO_PREFILL_ASSISTANT"));
|
||||
add_opt(common_arg(
|
||||
{"-sps", "--slot-prompt-similarity"}, "SIMILARITY",
|
||||
string_format("how much the prompt of a request must match the prompt of a slot in order to use that slot (default: %.2f, 0.0 = disabled)\n", params.slot_prompt_similarity),
|
||||
|
||||
@@ -368,6 +368,7 @@ struct common_params {
|
||||
bool use_jinja = false; // NOLINT
|
||||
bool enable_chat_template = true;
|
||||
common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK;
|
||||
bool prefill_assistant = true; // if true, any trailing assistant message will be prefilled into the response
|
||||
|
||||
std::vector<std::string> api_keys;
|
||||
|
||||
|
||||
@@ -1,3 +1,14 @@
|
||||
# llama.cpp/example/parallel
|
||||
|
||||
Simplified simulation of serving incoming requests in parallel
|
||||
|
||||
## Example
|
||||
|
||||
Generate 128 client requests (`-ns 128`), simulating 8 concurrent clients (`-np 8`). The system prompt is shared (`-pps`), meaning that it is computed once at the start. The client requests consist of 10 junk questions (`-j 10`) followed by the actual question.
|
||||
|
||||
```bash
|
||||
llama-parallel -m model.gguf -np 8 -ns 128 --top-k 1 -pps --junk 10 -c 16384
|
||||
```
|
||||
|
||||
> [!NOTE]
|
||||
> It's recommended to use base models with this example. Instruction tuned models might not be able to properly follow the custom chat template specified here, so the results might not be as expected.
|
||||
|
||||
@@ -34,11 +34,61 @@ static std::string k_system =
|
||||
R"(Transcript of a never ending dialog, where the User interacts with an Assistant.
|
||||
The Assistant is helpful, kind, honest, good at writing, and never fails to answer the User's requests immediately and with precision.
|
||||
|
||||
User: Recommend a nice restaurant in the area.
|
||||
Assistant: I recommend the restaurant "The Golden Duck". It is a 5 star restaurant with a great view of the city. The food is delicious and the service is excellent. The prices are reasonable and the portions are generous. The restaurant is located at 123 Main Street, New York, NY 10001. The phone number is (212) 555-1234. The hours are Monday through Friday from 11:00 am to 10:00 pm. The restaurant is closed on Saturdays and Sundays.
|
||||
User: Who is Richard Feynman?
|
||||
Assistant: Richard Feynman was an American physicist who is best known for his work in quantum mechanics and particle physics. He was awarded the Nobel Prize in Physics in 1965 for his contributions to the development of quantum electrodynamics. He was a popular lecturer and author, and he wrote several books, including "Surely You're Joking, Mr. Feynman!" and "What Do You Care What Other People Think?".
|
||||
User:)";
|
||||
User:
|
||||
Recommend a nice restaurant in the area.
|
||||
Assistant:
|
||||
I recommend the restaurant "The Golden Duck". It is a 5 star restaurant with a great view of the city. The food is delicious and the service is excellent. The prices are reasonable and the portions are generous. The restaurant is located at 123 Main Street, New York, NY 10001. The phone number is (212) 555-1234. The hours are Monday through Friday from 11:00 am to 10:00 pm. The restaurant is closed on Saturdays and Sundays.
|
||||
User:
|
||||
Who is Richard Feynman?
|
||||
Assistant:
|
||||
Richard Feynman was an American physicist who is best known for his work in quantum mechanics and particle physics. He was awarded the Nobel Prize in Physics in 1965 for his contributions to the development of quantum electrodynamics. He was a popular lecturer and author, and he wrote several books, including "Surely You're Joking, Mr. Feynman!" and "What Do You Care What Other People Think?".
|
||||
)";
|
||||
|
||||
static std::vector<std::string> k_questions = {
|
||||
"What is the tallest mountain in the world?",
|
||||
"Who was the first person to win two Nobel Prizes?",
|
||||
"Which country invented paper?",
|
||||
"What organ is primarily responsible for pumping blood throughout the body?",
|
||||
"Which planet is known for its prominent ring system?",
|
||||
"Who directed the movie 'Inception'?",
|
||||
"What is the freezing point of water in Fahrenheit?",
|
||||
"Which animal is known to have the longest lifespan?",
|
||||
"What language has the most native speakers worldwide?",
|
||||
"What is the capital city of Canada?",
|
||||
"Who is credited with inventing the World Wide Web?",
|
||||
"Which metal is liquid at room temperature?",
|
||||
"What is the term for an animal that eats both plants and meat?",
|
||||
"Who painted 'The Starry Night'?",
|
||||
"What gas do humans exhale that plants use for photosynthesis?",
|
||||
"What year did World War II end?",
|
||||
"Which continent has the most countries?",
|
||||
"Who wrote the novel 'Frankenstein'?",
|
||||
"What does DNA stand for?",
|
||||
"What is the main ingredient in traditional Japanese miso soup?"
|
||||
};
|
||||
|
||||
static std::vector<std::string> k_answers = {
|
||||
"The tallest mountain in the world is Mount Everest.",
|
||||
"Marie Curie was the first person to win two Nobel Prizes.",
|
||||
"Paper was invented in China.",
|
||||
"The heart is the organ responsible for pumping blood.",
|
||||
"Saturn is known for its prominent ring system.",
|
||||
"Christopher Nolan directed the movie 'Inception'.",
|
||||
"The freezing point of water in Fahrenheit is 32°F.",
|
||||
"The bowhead whale is known to have the longest lifespan among mammals.",
|
||||
"Mandarin Chinese has the most native speakers in the world.",
|
||||
"The capital city of Canada is Ottawa.",
|
||||
"Tim Berners-Lee is credited with inventing the World Wide Web.",
|
||||
"Mercury is the metal that is liquid at room temperature.",
|
||||
"An animal that eats both plants and meat is called an omnivore.",
|
||||
"'The Starry Night' was painted by Vincent van Gogh.",
|
||||
"Humans exhale carbon dioxide, which plants use in photosynthesis.",
|
||||
"World War II ended in 1945.",
|
||||
"Africa is the continent with the most countries.",
|
||||
"The novel 'Frankenstein' was written by Mary Shelley.",
|
||||
"DNA stands for Deoxyribonucleic Acid.",
|
||||
"The main ingredient in traditional Japanese miso soup is fermented soybean paste."
|
||||
};
|
||||
|
||||
static std::vector<std::string> k_prompts = {
|
||||
"What is the meaning of life?",
|
||||
@@ -49,7 +99,7 @@ static std::vector<std::string> k_prompts = {
|
||||
"What is the best way to learn a new language?",
|
||||
"How to get a job at Google?",
|
||||
"If you could have any superpower, what would it be?",
|
||||
"I want to learn how to play the piano.",
|
||||
"I want to learn how to play the piano. What would be the best way to do it?",
|
||||
};
|
||||
|
||||
struct client {
|
||||
@@ -68,6 +118,7 @@ struct client {
|
||||
int64_t t_start_prompt;
|
||||
int64_t t_start_gen;
|
||||
|
||||
int32_t n_past = 0;
|
||||
int32_t n_prompt = 0;
|
||||
int32_t n_decoded = 0;
|
||||
int32_t i_batch = -1;
|
||||
@@ -107,6 +158,7 @@ int main(int argc, char ** argv) {
|
||||
common_params params;
|
||||
|
||||
params.n_predict = 128;
|
||||
params.n_junk = 0;
|
||||
|
||||
if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_PARALLEL)) {
|
||||
return 1;
|
||||
@@ -128,6 +180,12 @@ int main(int argc, char ** argv) {
|
||||
|
||||
const bool dump_kv_cache = params.dump_kv_cache;
|
||||
|
||||
// is the system prompt shared in the cache
|
||||
const bool is_sp_shared = params.is_pp_shared;
|
||||
|
||||
// extra text to insert in each client's prompt in order to make it larger
|
||||
const int32_t n_junk = params.n_junk;
|
||||
|
||||
// init llama.cpp
|
||||
llama_backend_init();
|
||||
llama_numa_init(params.numa);
|
||||
@@ -169,6 +227,7 @@ int main(int argc, char ** argv) {
|
||||
}
|
||||
|
||||
std::vector<llama_token> tokens_system;
|
||||
|
||||
tokens_system = common_tokenize(ctx, k_system, true);
|
||||
const int32_t n_tokens_system = tokens_system.size();
|
||||
|
||||
@@ -190,7 +249,7 @@ int main(int argc, char ** argv) {
|
||||
LOG_INF("%s: n_parallel = %d, n_sequences = %d, cont_batching = %d, system tokens = %d\n", __func__, n_clients, n_seq, cont_batching, n_tokens_system);
|
||||
LOG_INF("\n");
|
||||
|
||||
{
|
||||
if (is_sp_shared) {
|
||||
LOG_INF("%s: Evaluating the system prompt ...\n", __func__);
|
||||
|
||||
for (int32_t i = 0; i < n_tokens_system; ++i) {
|
||||
@@ -228,7 +287,7 @@ int main(int argc, char ** argv) {
|
||||
|
||||
client.i_batch = batch.n_tokens;
|
||||
|
||||
common_batch_add(batch, client.sampled, n_tokens_system + client.n_prompt + client.n_decoded, { client.id + 1 }, true);
|
||||
common_batch_add(batch, client.sampled, client.n_past++, { client.id + 1 }, true);
|
||||
|
||||
client.n_decoded += 1;
|
||||
}
|
||||
@@ -254,9 +313,23 @@ int main(int argc, char ** argv) {
|
||||
client.t_start_gen = 0;
|
||||
|
||||
client.input = k_prompts[rand() % k_prompts.size()];
|
||||
client.prompt = client.input + "\nAssistant:";
|
||||
client.response = "";
|
||||
|
||||
// construct the prompt:
|
||||
// [system prompt] + [junk] + [user prompt]
|
||||
client.n_past = 0;
|
||||
client.prompt = "";
|
||||
if (is_sp_shared) {
|
||||
client.n_past = n_tokens_system;
|
||||
} else {
|
||||
client.prompt += k_system;
|
||||
}
|
||||
for (int i = 0; i < n_junk; ++i) {
|
||||
const int r = rand() % k_questions.size();
|
||||
client.prompt += "User:\n" + k_questions[r] + "\nAssistant:\n " + k_answers[r] + "\n";
|
||||
}
|
||||
client.prompt += "User:\n" + client.input + "\nAssistant:\n";
|
||||
|
||||
common_sampler_reset(client.smpl);
|
||||
|
||||
// do not prepend BOS because we have a system prompt!
|
||||
@@ -264,7 +337,7 @@ int main(int argc, char ** argv) {
|
||||
tokens_prompt = common_tokenize(ctx, client.prompt, false);
|
||||
|
||||
for (size_t i = 0; i < tokens_prompt.size(); ++i) {
|
||||
common_batch_add(batch, tokens_prompt[i], i + n_tokens_system, { client.id + 1 }, false);
|
||||
common_batch_add(batch, tokens_prompt[i], client.n_past++, { client.id + 1 }, false);
|
||||
}
|
||||
|
||||
// extract the logits only for the last token
|
||||
@@ -363,10 +436,9 @@ int main(int argc, char ** argv) {
|
||||
// client.id, client.seq_id, id, client.n_decoded, client.i_batch, token_str.c_str());
|
||||
|
||||
if (client.n_decoded > 2 &&
|
||||
(llama_vocab_is_eog(vocab, id) ||
|
||||
(params.n_predict > 0 && client.n_decoded + client.n_prompt >= params.n_predict) ||
|
||||
client.response.find("User:") != std::string::npos ||
|
||||
client.response.find('\n') != std::string::npos)) {
|
||||
(llama_vocab_is_eog(vocab, id) ||
|
||||
(params.n_predict > 0 && client.n_decoded >= params.n_predict) ||
|
||||
client.response.find("User:") != std::string::npos)) {
|
||||
// basic reverse prompt
|
||||
const size_t pos = client.response.find("User:");
|
||||
if (pos != std::string::npos) {
|
||||
|
||||
@@ -84,13 +84,13 @@ int main(int argc, char ** argv) {
|
||||
model_params.n_gpu_layers = ngl;
|
||||
|
||||
llama_model * model = llama_model_load_from_file(model_path.c_str(), model_params);
|
||||
const llama_vocab * vocab = llama_model_get_vocab(model);
|
||||
|
||||
if (model == NULL) {
|
||||
fprintf(stderr , "%s: error: unable to load model\n" , __func__);
|
||||
return 1;
|
||||
}
|
||||
|
||||
const llama_vocab * vocab = llama_model_get_vocab(model);
|
||||
// tokenize the prompt
|
||||
|
||||
// find the number of tokens in the prompt
|
||||
|
||||
@@ -128,6 +128,8 @@ extern "C" {
|
||||
// set gradients to zero, initilize loss, and optionally reset the optimizer
|
||||
GGML_API void ggml_opt_reset(ggml_opt_context_t opt_ctx, bool optimizer);
|
||||
|
||||
GGML_API bool ggml_opt_static_graphs(ggml_opt_context_t opt_ctx); // whether the graphs are allocated_statically
|
||||
|
||||
// get underlying tensors that store data
|
||||
// if not using static graphs these pointers become invalid with the next call to ggml_opt_alloc
|
||||
GGML_API struct ggml_tensor * ggml_opt_inputs( ggml_opt_context_t opt_ctx); // forward graph input tensor
|
||||
|
||||
@@ -65,6 +65,7 @@
|
||||
#include <aclnnop/aclnn_eq_tensor.h>
|
||||
#include <aclnnop/aclnn_gt_scalar.h>
|
||||
#include <aclnnop/aclnn_pow.h>
|
||||
#include <aclnnop/aclnn_grouped_matmul_v2.h>
|
||||
#include <float.h>
|
||||
|
||||
#include <cmath>
|
||||
@@ -2587,3 +2588,149 @@ void ggml_cann_step(ggml_backend_cann_context& ctx, ggml_tensor* dst){
|
||||
|
||||
ggml_cann_release_resources(ctx, acl_src, acl_dst, alpha);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Performs expert-specific matrix multiplication (MoE) with
|
||||
* floating-point precision using the CANN backend.
|
||||
*
|
||||
* This function executes a matrix multiplication operation tailored for
|
||||
* Mixture of Experts (MoE) models, where the input tensor is multiplied
|
||||
* with expert-specific weight matrices. It uses the CANN backend for
|
||||
* efficient computation and stores the result in the destination tensor `dst`.
|
||||
* The operation may leverage identity-based optimizations or routing masks
|
||||
* as part of sparse expert selection.
|
||||
*
|
||||
* @param ctx The context for executing CANN backend operations.
|
||||
* @param dst The destination tensor where the MoE multiplication result
|
||||
* will be stored.
|
||||
*
|
||||
* @note This function assumes floating-point data types and is designed for
|
||||
* MoE architectures, possibly involving sparse expert routing.
|
||||
*/
|
||||
static void ggml_cann_mul_mat_id_fp(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
|
||||
//dst [M, K, N, 1]
|
||||
ggml_tensor * src0 = dst->src[0]; //src0 [D, M, A, 1]
|
||||
ggml_tensor * src1 = dst->src[1]; //src1 [D, B, N, 1], B = K or B = 1
|
||||
ggml_tensor * ids = dst->src[2]; //ids [K, N]
|
||||
|
||||
GGML_TENSOR_BINARY_OP_LOCALS
|
||||
|
||||
// copy index from npu to cpu
|
||||
int64_t n_as = ne02; // A
|
||||
int64_t n_ids = ids->ne[0]; // K
|
||||
|
||||
std::vector<char> ids_host(ggml_nbytes(ids));
|
||||
ggml_cann_async_memcpy(ctx, ids_host.data(), ids->data, ggml_nbytes(ids),
|
||||
ACL_MEMCPY_DEVICE_TO_HOST);
|
||||
ACL_CHECK(aclrtSynchronizeStream(ctx.stream()));
|
||||
|
||||
char * src0_original = (char *) src0->data;
|
||||
char * src1_original = (char *) src1->data;
|
||||
char * dst_original = (char *) dst->data;
|
||||
size_t ori_src0_nb[4] = {nb00, nb01, nb02, nb03};
|
||||
|
||||
// src0 is F16, src1 is F32, dst is F32
|
||||
ggml_cann_pool_alloc src0_cast_allocator;
|
||||
if (src0->type == GGML_TYPE_F16) {
|
||||
src0_cast_allocator.alloc(ctx.pool(), sizeof(float) * ggml_nelements(src0));
|
||||
void* src0_cast_buf = src0_cast_allocator.get();
|
||||
|
||||
size_t cast_nb[GGML_MAX_DIMS];
|
||||
cast_nb[0] = sizeof(float_t);
|
||||
for (int i = 1; i < GGML_MAX_DIMS; i++) {
|
||||
cast_nb[i] = cast_nb[i - 1] * src0->ne[i - 1];
|
||||
}
|
||||
|
||||
aclTensor* acl_src0_f16 = ggml_cann_create_tensor(src0);
|
||||
aclTensor* acl_cast = ggml_cann_create_tensor(src0_cast_buf,
|
||||
ACL_FLOAT, sizeof(float), src0->ne, cast_nb, 4);
|
||||
GGML_CANN_CALL_ACLNN_OP(ctx, Cast, acl_src0_f16, ACL_FLOAT, acl_cast);
|
||||
ggml_cann_release_resources(ctx, acl_cast, acl_src0_f16);
|
||||
|
||||
src0_original = (char *) src0_cast_buf;
|
||||
memcpy(ori_src0_nb, cast_nb, sizeof(ori_src0_nb));
|
||||
}
|
||||
|
||||
std::vector<aclTensor*> src0_tensor_vec;
|
||||
std::vector<aclTensor*> src1_tensor_vec;
|
||||
std::vector<aclTensor*> dst_tensor_vec;
|
||||
for (int64_t iid1 = 0; iid1 < ids->ne[1]; iid1++) {
|
||||
for (int64_t id = 0; id < n_ids; id++) {
|
||||
// src0_row [M, D] -> weight && permute
|
||||
int64_t src0_ne[2] = {ne01, ne00};
|
||||
size_t src0_nb[2] = {ori_src0_nb[1], ori_src0_nb[0]};
|
||||
// src1_row [D, 1] -> input
|
||||
int64_t src1_ne[2] = {ne10, 1};
|
||||
size_t src1_nb[2] = {nb10, nb11};
|
||||
// dst_row [M, 1] -> out
|
||||
int64_t dst_ne[2] = {ne0, 1};
|
||||
size_t dst_nb[2] = {nb0, nb1};
|
||||
|
||||
// expert index
|
||||
int32_t i02 = *(int32_t *) (ids_host.data() + iid1*ids->nb[1] + id*ids->nb[0]);
|
||||
GGML_ASSERT(i02 >= 0 && i02 < n_as);
|
||||
|
||||
// If B = 1 (broadcast), always use 0; otherwise, use id.
|
||||
int64_t i11 = (ne11 == 1 ? 0 : id);
|
||||
int64_t i12 = iid1;
|
||||
|
||||
int64_t i1 = id;
|
||||
int64_t i2 = i12;
|
||||
|
||||
void* src0_tmp_ptr = src0_original + i02*ori_src0_nb[2];
|
||||
void* src1_tmp_ptr = src1_original + i11*nb11 + i12*nb12;
|
||||
void* dst_tmp_ptr = dst_original + i1*nb1 + i2*nb2;
|
||||
|
||||
aclTensor* acl_src0 = ggml_cann_create_tensor(src0_tmp_ptr,
|
||||
ACL_FLOAT, sizeof(float),
|
||||
src0_ne, src0_nb, 2);
|
||||
aclTensor* acl_src1 = ggml_cann_create_tensor(src1_tmp_ptr,
|
||||
ACL_FLOAT, sizeof(float),
|
||||
src1_ne, src1_nb, 2);
|
||||
aclTensor* acl_dst = ggml_cann_create_tensor(dst_tmp_ptr,
|
||||
ACL_FLOAT, sizeof(float),
|
||||
dst_ne, dst_nb, 2);
|
||||
|
||||
src0_tensor_vec.push_back(acl_src0);
|
||||
src1_tensor_vec.push_back(acl_src1);
|
||||
dst_tensor_vec.push_back(acl_dst);
|
||||
}
|
||||
}
|
||||
|
||||
// GroupedMatmulV2 required tensor_list.size < 128
|
||||
size_t GROUP_SIZE = 128;
|
||||
std::vector<std::vector<aclTensor*>> src0_tensor_vec_vec;
|
||||
std::vector<std::vector<aclTensor*>> src1_tensor_vec_vec;
|
||||
std::vector<std::vector<aclTensor*>> dst_tensor_vec_vec;
|
||||
|
||||
// split and call GroupedMatmulV2
|
||||
for (size_t i = 0; i < src0_tensor_vec.size(); i += GROUP_SIZE) {
|
||||
size_t end = std::min(i + GROUP_SIZE, src0_tensor_vec.size());
|
||||
std::vector<aclTensor*> src0_tensor_vec_split(src0_tensor_vec.begin() + i, src0_tensor_vec.begin() + end);
|
||||
std::vector<aclTensor*> src1_tensor_vec_split(src1_tensor_vec.begin() + i, src1_tensor_vec.begin() + end);
|
||||
std::vector<aclTensor*> dst_tensor_vec_split(dst_tensor_vec.begin() + i, dst_tensor_vec.begin() + end);
|
||||
|
||||
aclTensorList* src0_tensor_list = aclCreateTensorList(src0_tensor_vec_split.data(), src0_tensor_vec_split.size());
|
||||
aclTensorList* src1_tensor_list = aclCreateTensorList(src1_tensor_vec_split.data(), src1_tensor_vec_split.size());
|
||||
aclTensorList* dst_tensor_list = aclCreateTensorList(dst_tensor_vec_split.data(), dst_tensor_vec_split.size());
|
||||
|
||||
GGML_CANN_CALL_ACLNN_OP(ctx, GroupedMatmulV2, src1_tensor_list, src0_tensor_list,
|
||||
nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, 0, -1, dst_tensor_list);
|
||||
|
||||
ggml_cann_release_resources(ctx, src0_tensor_list, src1_tensor_list, dst_tensor_list);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
void ggml_cann_mul_mat_id(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
|
||||
const enum ggml_type type = dst->src[0]->type;
|
||||
switch (type) {
|
||||
case GGML_TYPE_F32:
|
||||
case GGML_TYPE_F16:
|
||||
ggml_cann_mul_mat_id_fp(ctx, dst);
|
||||
break;
|
||||
default:
|
||||
GGML_ABORT("Unsupported type for mul_mat_id");
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -978,6 +978,33 @@ inline void ggml_cann_async_memset(ggml_backend_cann_context & ctx, void * buffe
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Performs sparse expert-based matrix multiplication using the CANN backend.
|
||||
*
|
||||
* @details This function implements a MoE-style batched matrix multiplication, where each input token
|
||||
* is routed to one or more experts, and each expert corresponds to a specific [D, M] weight matrix
|
||||
* in the source tensor `src0`. The routing indices are provided via the `ids` tensor.
|
||||
*
|
||||
* For each token (from `src1`), the function selects the corresponding expert(s) as specified by `ids`,
|
||||
* performs the matrix multiplication with the selected expert's weight submatrix (from `src0`),
|
||||
* and stores the results in `dst`. This operation is optimized and executed on the CANN backend.
|
||||
*
|
||||
* Dimensions:
|
||||
* - src0: [D, M, A, 1], where A is the number of experts
|
||||
* - src1: [D, B, N, 1], where N is batch size and B is the slot count per sample
|
||||
* - ids : [K, N], where K is the number of experts each token is routed to
|
||||
* - dst : [M, K, N, 1], output tensor storing the result of expert × token multiplication
|
||||
*
|
||||
* The function handles two main modes:
|
||||
* - If `ne12 == 1`, a simpler per-token loop is used.
|
||||
* - TODO: If `ne12 > 1`, grouped multiplication and memory copying is used for efficiency.
|
||||
*
|
||||
* @param ctx The CANN context used for operations.
|
||||
* @param dst The destination tensor where the expert-weighted token outputs are stored.
|
||||
* Expected to be of shape [M, K, N, 1].
|
||||
*/
|
||||
void ggml_cann_mul_mat_id(ggml_backend_cann_context& ctx, ggml_tensor* dst);
|
||||
|
||||
/**
|
||||
* @brief Applies a element-wise operation to two input tensors using the CANN
|
||||
* backend.
|
||||
|
||||
@@ -1672,7 +1672,8 @@ static bool ggml_cann_compute_forward(ggml_backend_cann_context& ctx,
|
||||
ggml_cann_mul_mat(ctx, dst);
|
||||
break;
|
||||
case GGML_OP_MUL_MAT_ID:
|
||||
return false;
|
||||
ggml_cann_mul_mat_id(ctx, dst);
|
||||
break;
|
||||
case GGML_OP_SCALE:
|
||||
ggml_cann_scale(ctx, dst);
|
||||
break;
|
||||
@@ -2030,7 +2031,13 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
|
||||
}
|
||||
}
|
||||
case GGML_OP_MUL_MAT_ID:
|
||||
return false;
|
||||
switch (op->src[0]->type) {
|
||||
case GGML_TYPE_F16:
|
||||
case GGML_TYPE_F32:
|
||||
return true;
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
// embedding
|
||||
case GGML_OP_GET_ROWS: {
|
||||
switch (op->src[0]->type) {
|
||||
|
||||
@@ -576,6 +576,10 @@ void ggml_opt_reset(ggml_opt_context_t opt_ctx, bool optimizer) {
|
||||
}
|
||||
}
|
||||
|
||||
bool ggml_opt_static_graphs(ggml_opt_context_t opt_ctx) {
|
||||
return opt_ctx->static_graphs;
|
||||
}
|
||||
|
||||
struct ggml_tensor * ggml_opt_inputs(ggml_opt_context_t opt_ctx) {
|
||||
return opt_ctx->inputs;
|
||||
}
|
||||
@@ -842,6 +846,7 @@ void ggml_opt_epoch(
|
||||
int64_t idata_split,
|
||||
ggml_opt_epoch_callback callback_train,
|
||||
ggml_opt_epoch_callback callback_eval) {
|
||||
GGML_ASSERT(ggml_opt_static_graphs(opt_ctx) && "ggml_opt_epoch requires static graphs");
|
||||
struct ggml_tensor * inputs = ggml_opt_inputs(opt_ctx);
|
||||
struct ggml_tensor * labels = ggml_opt_labels(opt_ctx);
|
||||
struct ggml_tensor * data = ggml_opt_dataset_data(dataset);
|
||||
|
||||
@@ -54,6 +54,11 @@ if (Vulkan_FOUND)
|
||||
-DCMAKE_RUNTIME_OUTPUT_DIRECTORY=${CMAKE_RUNTIME_OUTPUT_DIRECTORY}
|
||||
)
|
||||
|
||||
set(VULKAN_SHADER_GEN_CMAKE_BUILD_ARGS "")
|
||||
if (CMAKE_BUILD_TYPE AND CMAKE_BUILD_TYPE MATCHES "Debug|Release|MinSizeRel|RelWithDebInfo")
|
||||
list(APPEND VULKAN_SHADER_GEN_CMAKE_BUILD_ARGS --config=${CMAKE_BUILD_TYPE})
|
||||
endif()
|
||||
|
||||
# Test all shader extensions
|
||||
test_shader_extension_support(
|
||||
"GL_KHR_cooperative_matrix"
|
||||
@@ -149,7 +154,7 @@ if (Vulkan_FOUND)
|
||||
vulkan-shaders-gen
|
||||
SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders
|
||||
CMAKE_ARGS ${VULKAN_SHADER_GEN_CMAKE_ARGS}
|
||||
BUILD_COMMAND ${CMAKE_COMMAND} --build .
|
||||
BUILD_COMMAND ${CMAKE_COMMAND} --build . ${VULKAN_SHADER_GEN_CMAKE_BUILD_ARGS}
|
||||
INSTALL_COMMAND ${CMAKE_COMMAND} --install .
|
||||
INSTALL_DIR ${CMAKE_BINARY_DIR}
|
||||
)
|
||||
|
||||
@@ -5872,10 +5872,17 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
|
||||
vk_pipeline *pipelines;
|
||||
bool small_rows = N <= get_fa_num_small_rows(path);
|
||||
|
||||
// coopmat1 does not actually support "small rows" (it needs 16 rows).
|
||||
// So use scalar instead.
|
||||
if (small_rows && path == FA_COOPMAT1) {
|
||||
path = FA_SCALAR;
|
||||
}
|
||||
|
||||
// scalar is faster than coopmat2 when N==1
|
||||
if (N == 1 && path == FA_COOPMAT2) {
|
||||
path = FA_SCALAR;
|
||||
}
|
||||
|
||||
bool f32acc = path == FA_SCALAR || dst->op_params[3] == GGML_PREC_F32;
|
||||
|
||||
switch (path) {
|
||||
|
||||
@@ -9,60 +9,13 @@
|
||||
#extension GL_KHR_shader_subgroup_shuffle : enable
|
||||
|
||||
#include "types.comp"
|
||||
#include "flash_attn_base.comp"
|
||||
|
||||
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
|
||||
|
||||
layout (constant_id = 0) const uint32_t WorkGroupSize = 128;
|
||||
layout (constant_id = 1) const uint32_t Br = 1;
|
||||
layout (constant_id = 2) const uint32_t Bc = 32;
|
||||
layout (constant_id = 3) const uint32_t D = 32;
|
||||
|
||||
layout (constant_id = 5) const uint32_t D_split = 16;
|
||||
const uint32_t D_per_thread = D / D_split;
|
||||
|
||||
const uint32_t cols_per_iter = WorkGroupSize / D_split;
|
||||
const uint32_t cols_per_thread = Bc / cols_per_iter;
|
||||
|
||||
layout (push_constant) uniform parameter {
|
||||
uint32_t N;
|
||||
uint32_t KV;
|
||||
|
||||
uint32_t ne1;
|
||||
uint32_t ne2;
|
||||
uint32_t ne3;
|
||||
|
||||
uint32_t neq2;
|
||||
uint32_t neq3;
|
||||
uint32_t nek2;
|
||||
uint32_t nek3;
|
||||
uint32_t nev2;
|
||||
uint32_t nev3;
|
||||
uint32_t nem1;
|
||||
|
||||
uint32_t nb01;
|
||||
uint32_t nb02;
|
||||
uint32_t nb03;
|
||||
uint32_t nb11;
|
||||
uint32_t nb12;
|
||||
uint32_t nb13;
|
||||
uint32_t nb21;
|
||||
uint32_t nb22;
|
||||
uint32_t nb23;
|
||||
uint32_t nb31;
|
||||
|
||||
float scale;
|
||||
float max_bias;
|
||||
float logit_softcap;
|
||||
|
||||
uint32_t mask;
|
||||
uint32_t n_head_log2;
|
||||
float m0;
|
||||
float m1;
|
||||
|
||||
uint32_t gqa_ratio;
|
||||
uint32_t split_kv;
|
||||
uint32_t k_num;
|
||||
} p;
|
||||
|
||||
layout (binding = 0) readonly buffer Q {float data_q[];};
|
||||
layout (binding = 0) readonly buffer QV4 {vec4 data_qv4[];};
|
||||
@@ -71,39 +24,6 @@ layout (binding = 1) readonly buffer KV4 {f16vec4 data_kv4[];};
|
||||
layout (binding = 2) readonly buffer V {float16_t data_v[];};
|
||||
layout (binding = 2) readonly buffer VV4 {f16vec4 data_vv4[];};
|
||||
layout (binding = 3) readonly buffer M {float16_t data_m[];};
|
||||
layout (binding = 4) writeonly buffer O {D_TYPE data_o[];};
|
||||
|
||||
#if defined(A_TYPE_PACKED16)
|
||||
#define BINDING_IDX_K 0
|
||||
#define BINDING_IDX_V 1
|
||||
layout (binding = 1) readonly buffer KV_PACKED16 {A_TYPE_PACKED16 data_packed16[];} kv_packed[2];
|
||||
#endif
|
||||
|
||||
#if defined(DATA_A_Q4_0)
|
||||
#define BLOCK_BYTE_SIZE 18
|
||||
|
||||
vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) {
|
||||
uint vui_lo = uint(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0]);
|
||||
uint vui_hi = uint(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]);
|
||||
uint shift = (iqs & 0x10) >> 2;
|
||||
vui_lo >>= shift;
|
||||
vui_hi >>= shift;
|
||||
|
||||
return float(kv_packed[binding_idx].data_packed16[a_offset + ib].d) * (vec4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, vui_hi & 0xF, (vui_hi >> 8) & 0xF) - 8.0f);
|
||||
}
|
||||
#endif
|
||||
|
||||
#if defined(DATA_A_Q8_0)
|
||||
#define BLOCK_BYTE_SIZE 34
|
||||
vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) {
|
||||
const i8vec2 v0 = unpack8(int32_t(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[iqs / 2])).xy; // vec4 used due to #12147
|
||||
const i8vec2 v1 = unpack8(int32_t(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[iqs / 2 + 1])).xy;
|
||||
|
||||
return float(kv_packed[binding_idx].data_packed16[a_offset + ib].d) * vec4(v0.x, v0.y, v1.x, v1.y);
|
||||
}
|
||||
#endif
|
||||
|
||||
#define CEIL_DIV(a, b) (((a) + (b) - 1) / (b))
|
||||
|
||||
// Store the output when doing grouped query attention.
|
||||
// Rows index by Q's dimension 2, and the first N rows are valid.
|
||||
@@ -114,27 +34,6 @@ D_TYPE perElemOpGqaStore(const in uint32_t r, const in uint32_t c, const in D_TY
|
||||
return elem;
|
||||
}
|
||||
|
||||
// Store column zero. This is used to save per-row m and L values for split_k.
|
||||
ACC_TYPE perElemOpStoreCol0(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N)
|
||||
{
|
||||
if (r < N && c == 0) {
|
||||
uint32_t offset = iq2 + r;
|
||||
data_o[o_offset + offset] = D_TYPE(elem);
|
||||
}
|
||||
return elem;
|
||||
}
|
||||
|
||||
// Load the slope matrix, indexed by Q's dimension 2.
|
||||
ACC_TYPE perElemOpComputeSlope(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem, const in uint32_t iq2)
|
||||
{
|
||||
const uint32_t h = iq2 + (r % p.gqa_ratio);
|
||||
|
||||
const ACC_TYPE base = ACC_TYPE(h < p.n_head_log2 ? p.m0 : p.m1);
|
||||
const int exph = int(h < p.n_head_log2 ? h + 1 : 2*(h - p.n_head_log2) + 1);
|
||||
|
||||
return ACC_TYPE(pow(base, ACC_TYPE(exph)));
|
||||
}
|
||||
|
||||
shared FLOAT_TYPE tmpsh[WorkGroupSize];
|
||||
shared vec4 tmpshv4[WorkGroupSize];
|
||||
|
||||
@@ -146,58 +45,12 @@ void main() {
|
||||
init_iq_shmem(gl_WorkGroupSize);
|
||||
#endif
|
||||
|
||||
const uint32_t tid = gl_LocalInvocationIndex;
|
||||
const uint32_t N = p.N;
|
||||
const uint32_t KV = p.KV;
|
||||
init_indices();
|
||||
|
||||
const uint32_t tid = gl_LocalInvocationIndex;
|
||||
const uint32_t d_tid = gl_LocalInvocationIndex % D_split;
|
||||
const uint32_t col_tid = gl_LocalInvocationIndex / D_split;
|
||||
|
||||
uint32_t i = gl_WorkGroupID.x;
|
||||
uint32_t split_k_index = 0;
|
||||
|
||||
if (p.k_num > 1) {
|
||||
i = 0;
|
||||
split_k_index = gl_WorkGroupID.x;
|
||||
}
|
||||
|
||||
const uint32_t Tr = CEIL_DIV(N, Br);
|
||||
|
||||
const uint32_t start_j = split_k_index * p.split_kv / Bc;
|
||||
const uint32_t end_j = CEIL_DIV(min(KV, (split_k_index + 1) * p.split_kv), Bc);
|
||||
|
||||
// When not using grouped query attention, all rows share the same iq2, equal to gl_WorkGroupID.y.
|
||||
// When using grouped query attention, each workgroup does gqa_ratio consecutive values of iq2.
|
||||
const uint32_t iq2 = gl_WorkGroupID.y * p.gqa_ratio;
|
||||
const uint32_t iq3 = gl_WorkGroupID.z;
|
||||
|
||||
// broadcast factors
|
||||
const uint32_t rk2 = p.neq2/p.nek2;
|
||||
const uint32_t rk3 = p.neq3/p.nek3;
|
||||
|
||||
const uint32_t rv2 = p.neq2/p.nev2;
|
||||
const uint32_t rv3 = p.neq3/p.nev3;
|
||||
|
||||
// k indices
|
||||
const uint32_t ik3 = iq3 / rk3;
|
||||
const uint32_t ik2 = iq2 / rk2;
|
||||
|
||||
// v indices
|
||||
const uint32_t iv3 = iq3 / rv3;
|
||||
const uint32_t iv2 = iq2 / rv2;
|
||||
|
||||
// nb?1 are already divided by the type size and are in units of elements.
|
||||
// When using grouped query attention, Q is indexed by iq2, so the stride
|
||||
// should be nb02 (which is in bytes).
|
||||
uint32_t q_stride = p.gqa_ratio > 1 ? (p.nb02 / 4) : p.nb01;
|
||||
uint32_t k_stride = p.nb11;
|
||||
uint32_t v_stride = p.nb21;
|
||||
// When using grouped query attention, all rows use the same mask (stride 0).
|
||||
// "p.gqa_ratio >> 16" is just a roundabout way of writing zero
|
||||
// that prevents the compiler from folding the "&" through the select
|
||||
// and breaking the alignment detection.
|
||||
uint32_t m_stride = (p.gqa_ratio > 1) ? (p.gqa_ratio >> 16) : KV;
|
||||
|
||||
uint32_t q_offset = (iq2*p.nb02+iq3*p.nb03) / 4;
|
||||
|
||||
[[unroll]] for (uint32_t idx = 0; idx < Br * D / 4; idx += gl_WorkGroupSize.x) {
|
||||
|
||||
162
ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp
Normal file
162
ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp
Normal file
@@ -0,0 +1,162 @@
|
||||
|
||||
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
|
||||
|
||||
layout (constant_id = 0) const uint32_t WorkGroupSize = 128;
|
||||
layout (constant_id = 1) const uint32_t Br = 1;
|
||||
layout (constant_id = 2) const uint32_t Bc = 32;
|
||||
layout (constant_id = 3) const uint32_t D = 32;
|
||||
layout (constant_id = 4) const uint32_t Clamp = 0;
|
||||
layout (constant_id = 5) const uint32_t D_split = 16;
|
||||
|
||||
|
||||
layout (push_constant) uniform parameter {
|
||||
uint32_t N;
|
||||
uint32_t KV;
|
||||
|
||||
uint32_t ne1;
|
||||
uint32_t ne2;
|
||||
uint32_t ne3;
|
||||
|
||||
uint32_t neq2;
|
||||
uint32_t neq3;
|
||||
uint32_t nek2;
|
||||
uint32_t nek3;
|
||||
uint32_t nev2;
|
||||
uint32_t nev3;
|
||||
uint32_t nem1;
|
||||
|
||||
uint32_t nb01;
|
||||
uint32_t nb02;
|
||||
uint32_t nb03;
|
||||
uint32_t nb11;
|
||||
uint32_t nb12;
|
||||
uint32_t nb13;
|
||||
uint32_t nb21;
|
||||
uint32_t nb22;
|
||||
uint32_t nb23;
|
||||
uint32_t nb31;
|
||||
|
||||
float scale;
|
||||
float max_bias;
|
||||
float logit_softcap;
|
||||
|
||||
uint32_t mask;
|
||||
uint32_t n_head_log2;
|
||||
float m0;
|
||||
float m1;
|
||||
|
||||
uint32_t gqa_ratio;
|
||||
uint32_t split_kv;
|
||||
uint32_t k_num;
|
||||
} p;
|
||||
|
||||
layout (binding = 4) writeonly buffer O {D_TYPE data_o[];};
|
||||
|
||||
#if defined(A_TYPE_PACKED16)
|
||||
#define BINDING_IDX_K 0
|
||||
#define BINDING_IDX_V 1
|
||||
layout (binding = 1) readonly buffer KV_PACKED16 {A_TYPE_PACKED16 data_packed16[];} kv_packed[2];
|
||||
#endif
|
||||
|
||||
#if defined(DATA_A_Q4_0)
|
||||
#define BLOCK_BYTE_SIZE 18
|
||||
|
||||
vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) {
|
||||
uint vui_lo = uint(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0]);
|
||||
uint vui_hi = uint(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]);
|
||||
uint shift = (iqs & 0x10) >> 2;
|
||||
vui_lo >>= shift;
|
||||
vui_hi >>= shift;
|
||||
|
||||
return float(kv_packed[binding_idx].data_packed16[a_offset + ib].d) * (vec4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, vui_hi & 0xF, (vui_hi >> 8) & 0xF) - 8.0f);
|
||||
}
|
||||
#endif
|
||||
|
||||
#if defined(DATA_A_Q8_0)
|
||||
#define BLOCK_BYTE_SIZE 34
|
||||
vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) {
|
||||
const i8vec2 v0 = unpack8(int32_t(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[iqs / 2])).xy; // vec4 used due to #12147
|
||||
const i8vec2 v1 = unpack8(int32_t(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[iqs / 2 + 1])).xy;
|
||||
|
||||
return float(kv_packed[binding_idx].data_packed16[a_offset + ib].d) * vec4(v0.x, v0.y, v1.x, v1.y);
|
||||
}
|
||||
#endif
|
||||
|
||||
#define CEIL_DIV(a, b) (((a) + (b) - 1) / (b))
|
||||
|
||||
|
||||
// Store column zero. This is used to save per-row m and L values for split_k.
|
||||
ACC_TYPE perElemOpStoreCol0(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N)
|
||||
{
|
||||
if (r < N && c == 0) {
|
||||
uint32_t offset = iq2 + r;
|
||||
data_o[o_offset + offset] = D_TYPE(elem);
|
||||
}
|
||||
return elem;
|
||||
}
|
||||
|
||||
// Load the slope matrix, indexed by Q's dimension 2.
|
||||
ACC_TYPE perElemOpComputeSlope(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem, const in uint32_t iq2)
|
||||
{
|
||||
const uint32_t h = iq2 + (r % p.gqa_ratio);
|
||||
|
||||
const ACC_TYPE base = ACC_TYPE(h < p.n_head_log2 ? p.m0 : p.m1);
|
||||
const int exph = int(h < p.n_head_log2 ? h + 1 : 2*(h - p.n_head_log2) + 1);
|
||||
|
||||
return ACC_TYPE(pow(base, ACC_TYPE(exph)));
|
||||
}
|
||||
|
||||
uint32_t i, N, KV, split_k_index, Tr, start_j, end_j,
|
||||
iq2, iq3, rk2, rk3, rv2, rv3, ik2, ik3, iv2, iv3,
|
||||
q_stride, k_stride, v_stride, m_stride;
|
||||
|
||||
void init_indices()
|
||||
{
|
||||
N = p.N;
|
||||
KV = p.KV;
|
||||
|
||||
i = gl_WorkGroupID.x;
|
||||
split_k_index = 0;
|
||||
|
||||
if (p.k_num > 1) {
|
||||
i = 0;
|
||||
split_k_index = gl_WorkGroupID.x;
|
||||
}
|
||||
|
||||
Tr = CEIL_DIV(N, Br);
|
||||
|
||||
start_j = split_k_index * p.split_kv / Bc;
|
||||
end_j = CEIL_DIV(min(KV, (split_k_index + 1) * p.split_kv), Bc);
|
||||
|
||||
// When not using grouped query attention, all rows share the same iq2, equal to gl_WorkGroupID.y.
|
||||
// When using grouped query attention, each workgroup does gqa_ratio consecutive values of iq2.
|
||||
iq2 = gl_WorkGroupID.y * p.gqa_ratio;
|
||||
iq3 = gl_WorkGroupID.z;
|
||||
|
||||
// broadcast factors
|
||||
rk2 = p.neq2/p.nek2;
|
||||
rk3 = p.neq3/p.nek3;
|
||||
|
||||
rv2 = p.neq2/p.nev2;
|
||||
rv3 = p.neq3/p.nev3;
|
||||
|
||||
// k indices
|
||||
ik3 = iq3 / rk3;
|
||||
ik2 = iq2 / rk2;
|
||||
|
||||
// v indices
|
||||
iv3 = iq3 / rv3;
|
||||
iv2 = iq2 / rv2;
|
||||
|
||||
// nb?1 are already divided by the type size and are in units of elements.
|
||||
// When using grouped query attention, Q is indexed by iq2, so the stride
|
||||
// should be nb02 (which is in bytes).
|
||||
q_stride = p.gqa_ratio > 1 ? (p.nb02 / 4) : p.nb01;
|
||||
k_stride = p.nb11;
|
||||
v_stride = p.nb21;
|
||||
// When using grouped query attention, all rows use the same mask (stride 0).
|
||||
// "p.gqa_ratio >> 16" is just a roundabout way of writing zero
|
||||
// that prevents the compiler from folding the "&" through the select
|
||||
// and breaking the alignment detection.
|
||||
m_stride = (p.gqa_ratio > 1) ? (p.gqa_ratio >> 16) : KV;
|
||||
}
|
||||
@@ -11,14 +11,7 @@
|
||||
#extension GL_KHR_cooperative_matrix : enable
|
||||
|
||||
#include "types.comp"
|
||||
|
||||
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
|
||||
|
||||
layout (constant_id = 1) const uint32_t Br = 1;
|
||||
layout (constant_id = 2) const uint32_t Bc = 32;
|
||||
layout (constant_id = 3) const uint32_t D = 32;
|
||||
|
||||
layout (constant_id = 5) const uint32_t D_split = 16;
|
||||
#include "flash_attn_base.comp"
|
||||
|
||||
const uint32_t D_per_thread = D / D_split;
|
||||
const uint32_t row_split = 4;
|
||||
@@ -26,46 +19,6 @@ const uint32_t rows_per_thread = Br / row_split;
|
||||
const uint32_t cols_per_iter = gl_WorkGroupSize.x / D_split / row_split;
|
||||
const uint32_t cols_per_thread = Bc / cols_per_iter;
|
||||
|
||||
layout (push_constant) uniform parameter {
|
||||
uint32_t N;
|
||||
uint32_t KV;
|
||||
|
||||
uint32_t ne1;
|
||||
uint32_t ne2;
|
||||
uint32_t ne3;
|
||||
|
||||
uint32_t neq2;
|
||||
uint32_t neq3;
|
||||
uint32_t nek2;
|
||||
uint32_t nek3;
|
||||
uint32_t nev2;
|
||||
uint32_t nev3;
|
||||
uint32_t nem1;
|
||||
|
||||
uint32_t nb01;
|
||||
uint32_t nb02;
|
||||
uint32_t nb03;
|
||||
uint32_t nb11;
|
||||
uint32_t nb12;
|
||||
uint32_t nb13;
|
||||
uint32_t nb21;
|
||||
uint32_t nb22;
|
||||
uint32_t nb23;
|
||||
uint32_t nb31;
|
||||
|
||||
float scale;
|
||||
float max_bias;
|
||||
float logit_softcap;
|
||||
|
||||
uint32_t mask;
|
||||
uint32_t n_head_log2;
|
||||
float m0;
|
||||
float m1;
|
||||
|
||||
uint32_t gqa_ratio;
|
||||
uint32_t split_kv;
|
||||
uint32_t k_num;
|
||||
} p;
|
||||
|
||||
layout (binding = 0) readonly buffer Q {float data_q[];};
|
||||
layout (binding = 0) readonly buffer QV4 {vec4 data_qv4[];};
|
||||
@@ -74,39 +27,6 @@ layout (binding = 1) readonly buffer KV4 {f16vec4 data_kv4[];};
|
||||
layout (binding = 2) readonly buffer V {float16_t data_v[];};
|
||||
layout (binding = 2) readonly buffer VV4 {f16vec4 data_vv4[];};
|
||||
layout (binding = 3) readonly buffer M {float16_t data_m[];};
|
||||
layout (binding = 4) writeonly buffer O {D_TYPE data_o[];};
|
||||
|
||||
#if defined(A_TYPE_PACKED16)
|
||||
#define BINDING_IDX_K 0
|
||||
#define BINDING_IDX_V 1
|
||||
layout (binding = 1) readonly buffer KV_PACKED16 {A_TYPE_PACKED16 data_packed16[];} kv_packed[2];
|
||||
#endif
|
||||
|
||||
#if defined(DATA_A_Q4_0)
|
||||
#define BLOCK_BYTE_SIZE 18
|
||||
|
||||
vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) {
|
||||
uint vui_lo = uint(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0]);
|
||||
uint vui_hi = uint(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]);
|
||||
uint shift = (iqs & 0x10) >> 2;
|
||||
vui_lo >>= shift;
|
||||
vui_hi >>= shift;
|
||||
|
||||
return float(kv_packed[binding_idx].data_packed16[a_offset + ib].d) * (vec4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, vui_hi & 0xF, (vui_hi >> 8) & 0xF) - 8.0f);
|
||||
}
|
||||
#endif
|
||||
|
||||
#if defined(DATA_A_Q8_0)
|
||||
#define BLOCK_BYTE_SIZE 34
|
||||
vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) {
|
||||
const i8vec2 v0 = unpack8(int32_t(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[iqs / 2])).xy; // vec4 used due to #12147
|
||||
const i8vec2 v1 = unpack8(int32_t(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[iqs / 2 + 1])).xy;
|
||||
|
||||
return float(kv_packed[binding_idx].data_packed16[a_offset + ib].d) * vec4(v0.x, v0.y, v1.x, v1.y);
|
||||
}
|
||||
#endif
|
||||
|
||||
#define CEIL_DIV(a, b) (((a) + (b) - 1) / (b))
|
||||
|
||||
// Store the output when doing grouped query attention.
|
||||
// Rows index by Q's dimension 2, and the first N rows are valid.
|
||||
@@ -117,27 +37,6 @@ D_TYPE perElemOpGqaStore(const in uint32_t r, const in uint32_t c, const in D_TY
|
||||
return elem;
|
||||
}
|
||||
|
||||
// Store column zero. This is used to save per-row m and L values for split_k.
|
||||
ACC_TYPE perElemOpStoreCol0(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N)
|
||||
{
|
||||
if (r < N && c == 0) {
|
||||
uint32_t offset = iq2 + r;
|
||||
data_o[o_offset + offset] = D_TYPE(elem);
|
||||
}
|
||||
return elem;
|
||||
}
|
||||
|
||||
// Load the slope matrix, indexed by Q's dimension 2.
|
||||
ACC_TYPE perElemOpComputeSlope(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem, const in uint32_t iq2)
|
||||
{
|
||||
const uint32_t h = iq2 + (r % p.gqa_ratio);
|
||||
|
||||
const ACC_TYPE base = ACC_TYPE(h < p.n_head_log2 ? p.m0 : p.m1);
|
||||
const int exph = int(h < p.n_head_log2 ? h + 1 : 2*(h - p.n_head_log2) + 1);
|
||||
|
||||
return ACC_TYPE(pow(base, ACC_TYPE(exph)));
|
||||
}
|
||||
|
||||
// These need to be supported N,M values for a MatBc x MatBr x 16 coopmatmuladd
|
||||
const uint32_t MatBr = 16;
|
||||
const uint32_t MatBc = 16;
|
||||
@@ -162,9 +61,9 @@ void main() {
|
||||
init_iq_shmem(gl_WorkGroupSize);
|
||||
#endif
|
||||
|
||||
init_indices();
|
||||
|
||||
const uint32_t tid = gl_LocalInvocationIndex;
|
||||
const uint32_t N = p.N;
|
||||
const uint32_t KV = p.KV;
|
||||
|
||||
const uint32_t threads_per_rowgroup = gl_WorkGroupSize.x / row_split;
|
||||
const uint32_t row_tid = gl_LocalInvocationIndex / threads_per_rowgroup;
|
||||
@@ -173,51 +72,6 @@ void main() {
|
||||
|
||||
#define tile_row(r) (row_tid * rows_per_thread + (r))
|
||||
|
||||
uint32_t i = gl_WorkGroupID.x;
|
||||
uint32_t split_k_index = 0;
|
||||
|
||||
if (p.k_num > 1) {
|
||||
i = 0;
|
||||
split_k_index = gl_WorkGroupID.x;
|
||||
}
|
||||
|
||||
const uint32_t Tr = CEIL_DIV(N, Br);
|
||||
|
||||
const uint32_t start_j = split_k_index * p.split_kv / Bc;
|
||||
const uint32_t end_j = CEIL_DIV(min(KV, (split_k_index + 1) * p.split_kv), Bc);
|
||||
|
||||
// When not using grouped query attention, all rows share the same iq2, equal to gl_WorkGroupID.y.
|
||||
// When using grouped query attention, each workgroup does gqa_ratio consecutive values of iq2.
|
||||
const uint32_t iq2 = gl_WorkGroupID.y * p.gqa_ratio;
|
||||
const uint32_t iq3 = gl_WorkGroupID.z;
|
||||
|
||||
// broadcast factors
|
||||
const uint32_t rk2 = p.neq2/p.nek2;
|
||||
const uint32_t rk3 = p.neq3/p.nek3;
|
||||
|
||||
const uint32_t rv2 = p.neq2/p.nev2;
|
||||
const uint32_t rv3 = p.neq3/p.nev3;
|
||||
|
||||
// k indices
|
||||
const uint32_t ik3 = iq3 / rk3;
|
||||
const uint32_t ik2 = iq2 / rk2;
|
||||
|
||||
// v indices
|
||||
const uint32_t iv3 = iq3 / rv3;
|
||||
const uint32_t iv2 = iq2 / rv2;
|
||||
|
||||
// nb?1 are already divided by the type size and are in units of elements.
|
||||
// When using grouped query attention, Q is indexed by iq2, so the stride
|
||||
// should be nb02 (which is in bytes).
|
||||
uint32_t q_stride = p.gqa_ratio > 1 ? (p.nb02 / 4) : p.nb01;
|
||||
uint32_t k_stride = p.nb11;
|
||||
uint32_t v_stride = p.nb21;
|
||||
// When using grouped query attention, all rows use the same mask (stride 0).
|
||||
// "p.gqa_ratio >> 16" is just a roundabout way of writing zero
|
||||
// that prevents the compiler from folding the "&" through the select
|
||||
// and breaking the alignment detection.
|
||||
uint32_t m_stride = (p.gqa_ratio > 1) ? (p.gqa_ratio >> 16) : KV;
|
||||
|
||||
uint32_t q_offset = (iq2*p.nb02+iq3*p.nb03) / 4;
|
||||
|
||||
[[unroll]] for (uint32_t idx = 0; idx < Br * D / 4; idx += gl_WorkGroupSize.x) {
|
||||
|
||||
@@ -18,62 +18,12 @@
|
||||
|
||||
#include "types.comp"
|
||||
#include "dequant_funcs_cm2.comp"
|
||||
|
||||
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
|
||||
|
||||
layout (constant_id = 1) const uint32_t Br = 32;
|
||||
layout (constant_id = 2) const uint32_t Bc = 32;
|
||||
layout (constant_id = 3) const uint32_t D = 32;
|
||||
layout (constant_id = 4) const uint32_t Clamp = gl_CooperativeMatrixClampModeConstantNV;
|
||||
|
||||
layout (push_constant) uniform parameter {
|
||||
uint32_t N;
|
||||
uint32_t KV;
|
||||
|
||||
uint32_t ne1;
|
||||
uint32_t ne2;
|
||||
uint32_t ne3;
|
||||
|
||||
uint32_t neq2;
|
||||
uint32_t neq3;
|
||||
uint32_t nek2;
|
||||
uint32_t nek3;
|
||||
uint32_t nev2;
|
||||
uint32_t nev3;
|
||||
uint32_t nem1;
|
||||
|
||||
uint32_t nb01;
|
||||
uint32_t nb02;
|
||||
uint32_t nb03;
|
||||
uint32_t nb11;
|
||||
uint32_t nb12;
|
||||
uint32_t nb13;
|
||||
uint32_t nb21;
|
||||
uint32_t nb22;
|
||||
uint32_t nb23;
|
||||
uint32_t nb31;
|
||||
|
||||
float scale;
|
||||
float max_bias;
|
||||
float logit_softcap;
|
||||
|
||||
uint32_t mask;
|
||||
uint32_t n_head_log2;
|
||||
float m0;
|
||||
float m1;
|
||||
|
||||
uint32_t gqa_ratio;
|
||||
uint32_t split_kv;
|
||||
uint32_t k_num;
|
||||
} p;
|
||||
#include "flash_attn_base.comp"
|
||||
|
||||
layout (binding = 0) readonly buffer Q {uint8_t data_q[];};
|
||||
layout (binding = 1) readonly buffer K {uint8_t data_k[];};
|
||||
layout (binding = 2) readonly buffer V {uint8_t data_v[];};
|
||||
layout (binding = 3) readonly buffer M {uint8_t data_m[];};
|
||||
layout (binding = 4) writeonly buffer O {D_TYPE data_o[];};
|
||||
|
||||
#define CEIL_DIV(a, b) (((a) + (b) - 1) / (b))
|
||||
|
||||
ACC_TYPE maxReduce(const in ACC_TYPE x, const in ACC_TYPE y) {
|
||||
return max(x, y);
|
||||
@@ -118,67 +68,12 @@ D_TYPE perElemOpGqaStore(const in uint32_t r, const in uint32_t c, const in D_TY
|
||||
return elem;
|
||||
}
|
||||
|
||||
// Store column zero. This is used to save per-row m and L values for split_k.
|
||||
ACC_TYPE perElemOpStoreCol0(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N)
|
||||
{
|
||||
if (r < N && c == 0) {
|
||||
uint32_t offset = iq2 + r;
|
||||
data_o[o_offset + offset] = D_TYPE(elem);
|
||||
}
|
||||
return elem;
|
||||
}
|
||||
|
||||
// Load the slope matrix, indexed by Q's dimension 2.
|
||||
ACC_TYPE perElemOpComputeSlope(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem, const in uint32_t iq2)
|
||||
{
|
||||
const uint32_t h = iq2 + (r % p.gqa_ratio);
|
||||
|
||||
const ACC_TYPE base = ACC_TYPE(h < p.n_head_log2 ? p.m0 : p.m1);
|
||||
const int exph = int(h < p.n_head_log2 ? h + 1 : 2*(h - p.n_head_log2) + 1);
|
||||
|
||||
return ACC_TYPE(pow(base, ACC_TYPE(exph)));
|
||||
}
|
||||
|
||||
void main() {
|
||||
#ifdef NEEDS_INIT_IQ_SHMEM
|
||||
init_iq_shmem(gl_WorkGroupSize);
|
||||
#endif
|
||||
|
||||
const uint32_t N = p.N;
|
||||
const uint32_t KV = p.KV;
|
||||
|
||||
uint32_t i = gl_WorkGroupID.x;
|
||||
uint32_t split_k_index = 0;
|
||||
|
||||
if (p.k_num > 1) {
|
||||
i = 0;
|
||||
split_k_index = gl_WorkGroupID.x;
|
||||
}
|
||||
|
||||
const uint32_t Tr = CEIL_DIV(N, Br);
|
||||
|
||||
const uint32_t start_j = split_k_index * p.split_kv / Bc;
|
||||
const uint32_t end_j = CEIL_DIV(min(KV, (split_k_index + 1) * p.split_kv), Bc);
|
||||
|
||||
// When not using grouped query attention, all rows share the same iq2, equal to gl_WorkGroupID.y.
|
||||
// When using grouped query attention, each workgroup does gqa_ratio consecutive values of iq2.
|
||||
const uint32_t iq2 = gl_WorkGroupID.y * p.gqa_ratio;
|
||||
const uint32_t iq3 = gl_WorkGroupID.z;
|
||||
|
||||
// broadcast factors
|
||||
const uint32_t rk2 = p.neq2/p.nek2;
|
||||
const uint32_t rk3 = p.neq3/p.nek3;
|
||||
|
||||
const uint32_t rv2 = p.neq2/p.nev2;
|
||||
const uint32_t rv3 = p.neq3/p.nev3;
|
||||
|
||||
// k indices
|
||||
const uint32_t ik3 = iq3 / rk3;
|
||||
const uint32_t ik2 = iq2 / rk2;
|
||||
|
||||
// v indices
|
||||
const uint32_t iv3 = iq3 / rv3;
|
||||
const uint32_t iv2 = iq2 / rv2;
|
||||
init_indices();
|
||||
|
||||
tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutQ = createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV);
|
||||
tensorLayoutNV<2, Clamp> tensorLayoutK = createTensorLayoutNV(2, Clamp);
|
||||
@@ -195,17 +90,6 @@ void main() {
|
||||
tensorLayoutK = setTensorLayoutDimensionNV(tensorLayoutK, KV, D);
|
||||
tensorLayoutV = setTensorLayoutDimensionNV(tensorLayoutV, KV, D);
|
||||
|
||||
// nb?1 are already divided by the type size and are in units of elements.
|
||||
// When using grouped query attention, Q is indexed by iq2, so the stride
|
||||
// should be nb02 (which is in bytes).
|
||||
uint32_t q_stride = p.gqa_ratio > 1 ? (p.nb02 / 4) : p.nb01;
|
||||
uint32_t k_stride = p.nb11;
|
||||
uint32_t v_stride = p.nb21;
|
||||
// When using grouped query attention, all rows use the same mask (stride 0).
|
||||
// "p.gqa_ratio >> 16" is just a roundabout way of writing zero
|
||||
// that prevents the compiler from folding the "&" through the select
|
||||
// and breaking the alignment detection.
|
||||
uint32_t m_stride = (p.gqa_ratio > 1) ? (p.gqa_ratio >> 16) : KV;
|
||||
// hint to the compiler that strides are aligned for the aligned variant of the shader
|
||||
if (Clamp != gl_CooperativeMatrixClampModeConstantNV)
|
||||
{
|
||||
|
||||
@@ -64,12 +64,17 @@
|
||||
// precomputed f32 table for f16 (256 KB) (ggml-impl.h)
|
||||
float ggml_table_f32_f16[1 << 16];
|
||||
|
||||
#if (defined(__linux__) || defined(__APPLE__) || defined(__FreeBSD__) || defined(__NetBSD__) || defined(__OpenBSD__)) && \
|
||||
(!defined(TARGET_OS_TV) && !defined(TARGET_OS_WATCH))
|
||||
#if defined(__linux__) || \
|
||||
defined(__FreeBSD__) || defined(__NetBSD__) || defined(__OpenBSD__) || \
|
||||
(defined(__APPLE__) && !TARGET_OS_TV && !TARGET_OS_WATCH)
|
||||
|
||||
#include <unistd.h>
|
||||
#include <sys/types.h>
|
||||
#include <sys/stat.h>
|
||||
#include <sys/wait.h>
|
||||
#if defined(__linux__)
|
||||
#include <sys/prctl.h>
|
||||
#endif
|
||||
|
||||
#if defined(__ANDROID__)
|
||||
#include <unwind.h>
|
||||
@@ -133,10 +138,36 @@ static void ggml_print_backtrace(void) {
|
||||
if (GGML_NO_BACKTRACE) {
|
||||
return;
|
||||
}
|
||||
char attach[32];
|
||||
snprintf(attach, sizeof(attach), "attach %d", getpid());
|
||||
int pid = fork();
|
||||
if (pid == 0) {
|
||||
#if defined(__linux__)
|
||||
FILE * f = fopen("/proc/self/status", "r");
|
||||
size_t size = 0;
|
||||
char * line = NULL;
|
||||
ssize_t length = 0;
|
||||
while ((length = getline(&line, &size, f)) > 0) {
|
||||
if (!strncmp(line, "TracerPid:", sizeof("TracerPid:") - 1) &&
|
||||
(length != sizeof("TracerPid:\t0\n") - 1 || line[length - 2] != '0')) {
|
||||
// Already being debugged, and the breakpoint is the later abort()
|
||||
free(line);
|
||||
fclose(f);
|
||||
return;
|
||||
}
|
||||
}
|
||||
free(line);
|
||||
fclose(f);
|
||||
int lock[2] = { -1, -1 };
|
||||
(void) !pipe(lock); // Don't start gdb until after PR_SET_PTRACER
|
||||
#endif
|
||||
const int parent_pid = getpid();
|
||||
const int child_pid = fork();
|
||||
if (child_pid < 0) { // error
|
||||
return;
|
||||
} else if (child_pid == 0) { // child
|
||||
char attach[32];
|
||||
snprintf(attach, sizeof(attach), "attach %d", parent_pid);
|
||||
#if defined(__linux__)
|
||||
close(lock[1]);
|
||||
(void) !read(lock[0], lock, 1);
|
||||
#endif
|
||||
// try gdb
|
||||
execlp("gdb", "gdb", "--batch",
|
||||
"-ex", "set style enabled on",
|
||||
@@ -149,18 +180,18 @@ static void ggml_print_backtrace(void) {
|
||||
execlp("lldb", "lldb", "--batch",
|
||||
"-o", "bt",
|
||||
"-o", "quit",
|
||||
"-p", attach,
|
||||
"-p", &attach[sizeof("attach ") - 1],
|
||||
(char *) NULL);
|
||||
exit(EXIT_FAILURE);
|
||||
} else {
|
||||
int wstatus;
|
||||
waitpid(pid, &wstatus, 0);
|
||||
if (WIFEXITED(wstatus)) {
|
||||
if (WEXITSTATUS(wstatus) == EXIT_FAILURE) {
|
||||
// gdb failed, fallback to backtrace_symbols
|
||||
ggml_print_backtrace_symbols();
|
||||
}
|
||||
}
|
||||
// gdb failed, fallback to backtrace_symbols
|
||||
ggml_print_backtrace_symbols();
|
||||
_Exit(0);
|
||||
} else { // parent
|
||||
#if defined(__linux__)
|
||||
prctl(PR_SET_PTRACER, child_pid);
|
||||
close(lock[1]);
|
||||
close(lock[0]);
|
||||
#endif
|
||||
waitpid(child_pid, NULL, 0);
|
||||
}
|
||||
}
|
||||
#else
|
||||
|
||||
@@ -1 +1 @@
|
||||
9b048bb72b811f50b0c30d9e5c84d6ff9f4bf005
|
||||
7c06c10c532a6cda913c17fc56341e8880ae341d
|
||||
|
||||
@@ -13,6 +13,7 @@ Set of LLM REST APIs and a simple web front end to interact with llama.cpp.
|
||||
* Multimodal ([documentation](../../docs/multimodal.md)) / with OpenAI-compatible API support
|
||||
* Monitoring endpoints
|
||||
* Schema-constrained JSON response format
|
||||
* Prefilling of assistant messages similar to the Claude API
|
||||
* [Function calling](../../docs/function-calling.md) / tool use for ~any model
|
||||
* Speculative decoding
|
||||
* Easy-to-use web UI
|
||||
@@ -175,6 +176,7 @@ The project is under active development, and we are [looking for feedback and co
|
||||
| `--reasoning-format FORMAT` | reasoning format (default: deepseek; allowed values: deepseek, none)<br/>controls whether thought tags are extracted from the response, and in which format they're returned. 'none' leaves thoughts unparsed in `message.content`, 'deepseek' puts them in `message.reasoning_content` (for DeepSeek R1 & Command R7B only).<br/>only supported for non-streamed responses<br/>(env: LLAMA_ARG_THINK) |
|
||||
| `--chat-template JINJA_TEMPLATE` | set custom jinja chat template (default: template taken from model's metadata)<br/>if suffix/prefix are specified, template will be disabled<br/>only commonly used templates are accepted (unless --jinja is set before this flag):<br/>list of built-in templates:<br/>bailing, chatglm3, chatglm4, chatml, command-r, deepseek, deepseek2, deepseek3, exaone3, falcon3, gemma, gigachat, glmedge, granite, llama2, llama2-sys, llama2-sys-bos, llama2-sys-strip, llama3, llama4, megrez, minicpm, mistral-v1, mistral-v3, mistral-v3-tekken, mistral-v7, mistral-v7-tekken, monarch, openchat, orion, phi3, phi4, rwkv-world, smolvlm, vicuna, vicuna-orca, yandex, zephyr<br/>(env: LLAMA_ARG_CHAT_TEMPLATE) |
|
||||
| `--chat-template-file JINJA_TEMPLATE_FILE` | set custom jinja chat template file (default: template taken from model's metadata)<br/>if suffix/prefix are specified, template will be disabled<br/>only commonly used templates are accepted (unless --jinja is set before this flag):<br/>list of built-in templates:<br/>bailing, chatglm3, chatglm4, chatml, command-r, deepseek, deepseek2, deepseek3, exaone3, falcon3, gemma, gigachat, glmedge, granite, llama2, llama2-sys, llama2-sys-bos, llama2-sys-strip, llama3, llama4, megrez, minicpm, mistral-v1, mistral-v3, mistral-v3-tekken, mistral-v7, mistral-v7-tekken, monarch, openchat, orion, phi3, phi4, rwkv-world, smolvlm, vicuna, vicuna-orca, yandex, zephyr<br/>(env: LLAMA_ARG_CHAT_TEMPLATE_FILE) |
|
||||
| `--no-prefill-assistant` | whether to prefill the assistant's response if the last message is an assistant message (default: prefill enabled)<br/>when this flag is set, if the last message is an assistant message then it will be treated as a full message and not prefilled<br/>(env: LLAMA_ARG_NO_PREFILL_ASSISTANT) |
|
||||
| `-sps, --slot-prompt-similarity SIMILARITY` | how much the prompt of a request must match the prompt of a slot in order to use that slot (default: 0.50, 0.0 = disabled)<br/> |
|
||||
| `--lora-init-without-apply` | load LoRA adapters without applying them (apply later via POST /lora-adapters) (default: disabled) |
|
||||
| `--draft-max, --draft, --draft-n N` | number of tokens to draft for speculative decoding (default: 16)<br/>(env: LLAMA_ARG_DRAFT_MAX) |
|
||||
|
||||
Binary file not shown.
@@ -2251,6 +2251,14 @@ struct server_context {
|
||||
slot.has_next_token = true;
|
||||
}
|
||||
|
||||
// if context shifting is disabled, make sure that we don't run out of context
|
||||
if (!params_base.ctx_shift && slot.n_past + 1 >= slot.n_ctx) {
|
||||
slot.stop = STOP_TYPE_LIMIT;
|
||||
slot.has_next_token = false;
|
||||
|
||||
SLT_DBG(slot, "stopped due to running out of context, n_past = %d, n_ctx = %d\n", slot.n_past, slot.n_ctx);
|
||||
}
|
||||
|
||||
// check the limits
|
||||
if (slot.n_decoded > 0 && slot.has_next_token && !slot.has_budget(params_base)) {
|
||||
slot.stop = STOP_TYPE_LIMIT;
|
||||
@@ -4340,6 +4348,7 @@ int main(int argc, char ** argv) {
|
||||
json data = oaicompat_completion_params_parse(
|
||||
body,
|
||||
params.use_jinja,
|
||||
params.prefill_assistant,
|
||||
params.reasoning_format,
|
||||
ctx_server.chat_templates.get(),
|
||||
ctx_server.mctx,
|
||||
@@ -4361,6 +4370,7 @@ int main(int argc, char ** argv) {
|
||||
json data = oaicompat_completion_params_parse(
|
||||
body,
|
||||
params.use_jinja,
|
||||
params.prefill_assistant,
|
||||
params.reasoning_format,
|
||||
ctx_server.chat_templates.get(),
|
||||
ctx_server.mctx,
|
||||
|
||||
@@ -65,3 +65,21 @@ def test_ctx_shift_disabled_long_prompt():
|
||||
assert res.status_code != 200
|
||||
assert "error" in res.body
|
||||
assert "exceeds the available context size" in res.body["error"]["message"]
|
||||
|
||||
def test_ctx_shift_disabled_stream():
|
||||
global server
|
||||
server.disable_ctx_shift = True
|
||||
server.start()
|
||||
res = server.make_stream_request("POST", "/v1/completions", data={
|
||||
"n_predict": 256,
|
||||
"prompt": "Once",
|
||||
"stream": True,
|
||||
})
|
||||
content = ""
|
||||
for data in res:
|
||||
choice = data["choices"][0]
|
||||
if choice["finish_reason"] == "length":
|
||||
assert len(content) > 0
|
||||
else:
|
||||
assert choice["finish_reason"] is None
|
||||
content += choice["text"]
|
||||
|
||||
@@ -583,6 +583,7 @@ static json oaicompat_completion_params_parse(const json & body) {
|
||||
static json oaicompat_completion_params_parse(
|
||||
const json & body, /* openai api json semantics */
|
||||
bool use_jinja,
|
||||
bool prefill_assistant,
|
||||
common_reasoning_format reasoning_format,
|
||||
const struct common_chat_templates * tmpls,
|
||||
bool allow_non_text,
|
||||
@@ -732,7 +733,7 @@ static json oaicompat_completion_params_parse(
|
||||
|
||||
// if the assistant message appears at the end of list, we do not add end-of-turn token
|
||||
// for ex. this can be useful to modify the reasoning process in reasoning models
|
||||
bool prefill_assistant_message = !inputs.messages.empty() && inputs.messages.back().role == "assistant";
|
||||
bool prefill_assistant_message = !inputs.messages.empty() && inputs.messages.back().role == "assistant" && prefill_assistant;
|
||||
common_chat_msg last_message;
|
||||
if (prefill_assistant_message) {
|
||||
last_message = inputs.messages.back();
|
||||
|
||||
@@ -28,13 +28,13 @@ function AppLayout() {
|
||||
return (
|
||||
<>
|
||||
<Sidebar />
|
||||
<div
|
||||
<main
|
||||
className="drawer-content grow flex flex-col h-screen w-screen mx-auto px-4 overflow-auto bg-base-100"
|
||||
id="main-scroll"
|
||||
>
|
||||
<Header />
|
||||
<Outlet />
|
||||
</div>
|
||||
</main>
|
||||
{
|
||||
<SettingDialog
|
||||
show={showSettings}
|
||||
|
||||
@@ -18,16 +18,26 @@ export default function ChatInputExtraContextItem({
|
||||
if (!items) return null;
|
||||
|
||||
return (
|
||||
<div className="flex flex-row gap-4 overflow-x-auto py-2 px-1 mb-1">
|
||||
<div
|
||||
className="flex flex-row gap-4 overflow-x-auto py-2 px-1 mb-1"
|
||||
role="group"
|
||||
aria-description="Selected files"
|
||||
>
|
||||
{items.map((item, i) => (
|
||||
<div
|
||||
className="indicator"
|
||||
key={i}
|
||||
onClick={() => clickToShow && setShow(i)}
|
||||
tabIndex={0}
|
||||
aria-description={
|
||||
clickToShow ? `Click to show: ${item.name}` : undefined
|
||||
}
|
||||
role={clickToShow ? 'button' : 'menuitem'}
|
||||
>
|
||||
{removeItem && (
|
||||
<div className="indicator-item indicator-top">
|
||||
<button
|
||||
aria-label="Remove file"
|
||||
className="btn btn-neutral btn-sm w-4 h-4 p-0 rounded-full"
|
||||
onClick={() => removeItem(i)}
|
||||
>
|
||||
@@ -46,13 +56,16 @@ export default function ChatInputExtraContextItem({
|
||||
<>
|
||||
<img
|
||||
src={item.base64Url}
|
||||
alt={item.name}
|
||||
alt={`Preview image for ${item.name}`}
|
||||
className="w-14 h-14 object-cover rounded-md"
|
||||
/>
|
||||
</>
|
||||
) : (
|
||||
<>
|
||||
<div className="w-14 h-14 flex items-center justify-center">
|
||||
<div
|
||||
className="w-14 h-14 flex items-center justify-center"
|
||||
aria-description="Document icon"
|
||||
>
|
||||
<DocumentTextIcon className="h-8 w-14 text-base-content/50" />
|
||||
</div>
|
||||
|
||||
@@ -66,16 +79,25 @@ export default function ChatInputExtraContextItem({
|
||||
))}
|
||||
|
||||
{showingItem && (
|
||||
<dialog className="modal modal-open">
|
||||
<dialog
|
||||
className="modal modal-open"
|
||||
aria-description={`Preview ${showingItem.name}`}
|
||||
>
|
||||
<div className="modal-box">
|
||||
<div className="flex justify-between items-center mb-4">
|
||||
<b>{showingItem.name ?? 'Extra content'}</b>
|
||||
<button className="btn btn-ghost btn-sm">
|
||||
<button
|
||||
className="btn btn-ghost btn-sm"
|
||||
aria-label="Close preview dialog"
|
||||
>
|
||||
<XMarkIcon className="h-5 w-5" onClick={() => setShow(-1)} />
|
||||
</button>
|
||||
</div>
|
||||
{showingItem.type === 'imageFile' ? (
|
||||
<img src={showingItem.base64Url} alt={showingItem.name} />
|
||||
<img
|
||||
src={showingItem.base64Url}
|
||||
alt={`Preview image for ${showingItem.name}`}
|
||||
/>
|
||||
) : (
|
||||
<div className="overflow-x-auto">
|
||||
<pre className="whitespace-pre-wrap break-words text-sm">
|
||||
|
||||
@@ -83,13 +83,20 @@ export default function ChatMessage({
|
||||
|
||||
if (!viewingChat) return null;
|
||||
|
||||
const isUser = msg.role === 'user';
|
||||
|
||||
return (
|
||||
<div className="group" id={id}>
|
||||
<div
|
||||
className="group"
|
||||
id={id}
|
||||
role="group"
|
||||
aria-description={`Message from ${msg.role}`}
|
||||
>
|
||||
<div
|
||||
className={classNames({
|
||||
chat: true,
|
||||
'chat-start': msg.role !== 'user',
|
||||
'chat-end': msg.role === 'user',
|
||||
'chat-start': !isUser,
|
||||
'chat-end': isUser,
|
||||
})}
|
||||
>
|
||||
{msg.extra && msg.extra.length > 0 && (
|
||||
@@ -99,7 +106,7 @@ export default function ChatMessage({
|
||||
<div
|
||||
className={classNames({
|
||||
'chat-bubble markdown': true,
|
||||
'chat-bubble bg-transparent': msg.role !== 'user',
|
||||
'chat-bubble bg-transparent': !isUser,
|
||||
})}
|
||||
>
|
||||
{/* textarea for editing message */}
|
||||
@@ -142,7 +149,7 @@ export default function ChatMessage({
|
||||
) : (
|
||||
<>
|
||||
{/* render message as markdown */}
|
||||
<div dir="auto">
|
||||
<div dir="auto" tabIndex={0}>
|
||||
{thought && (
|
||||
<ThoughtProcess
|
||||
isThinking={!!isThinking && !!isPending}
|
||||
@@ -196,13 +203,18 @@ export default function ChatMessage({
|
||||
})}
|
||||
>
|
||||
{siblingLeafNodeIds && siblingLeafNodeIds.length > 1 && (
|
||||
<div className="flex gap-1 items-center opacity-60 text-sm">
|
||||
<div
|
||||
className="flex gap-1 items-center opacity-60 text-sm"
|
||||
role="navigation"
|
||||
aria-description={`Message version ${siblingCurrIdx + 1} of ${siblingLeafNodeIds.length}`}
|
||||
>
|
||||
<button
|
||||
className={classNames({
|
||||
'btn btn-sm btn-ghost p-1': true,
|
||||
'opacity-20': !prevSibling,
|
||||
})}
|
||||
onClick={() => prevSibling && onChangeSibling(prevSibling)}
|
||||
aria-label="Previous message version"
|
||||
>
|
||||
<ChevronLeftIcon className="h-4 w-4" />
|
||||
</button>
|
||||
@@ -215,6 +227,7 @@ export default function ChatMessage({
|
||||
'opacity-20': !nextSibling,
|
||||
})}
|
||||
onClick={() => nextSibling && onChangeSibling(nextSibling)}
|
||||
aria-label="Next message version"
|
||||
>
|
||||
<ChevronRightIcon className="h-4 w-4" />
|
||||
</button>
|
||||
@@ -223,7 +236,7 @@ export default function ChatMessage({
|
||||
{/* user message */}
|
||||
{msg.role === 'user' && (
|
||||
<BtnWithTooltips
|
||||
className="btn-mini show-on-hover w-8 h-8"
|
||||
className="btn-mini w-8 h-8"
|
||||
onClick={() => setEditingContent(msg.content)}
|
||||
disabled={msg.content === null}
|
||||
tooltipsContent="Edit message"
|
||||
@@ -236,7 +249,7 @@ export default function ChatMessage({
|
||||
<>
|
||||
{!isPending && (
|
||||
<BtnWithTooltips
|
||||
className="btn-mini show-on-hover w-8 h-8"
|
||||
className="btn-mini w-8 h-8"
|
||||
onClick={() => {
|
||||
if (msg.content !== null) {
|
||||
onRegenerateMessage(msg as Message);
|
||||
@@ -250,10 +263,7 @@ export default function ChatMessage({
|
||||
)}
|
||||
</>
|
||||
)}
|
||||
<CopyButton
|
||||
className="btn-mini show-on-hover w-8 h-8"
|
||||
content={msg.content}
|
||||
/>
|
||||
<CopyButton className="btn-mini w-8 h-8" content={msg.content} />
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
@@ -271,6 +281,8 @@ function ThoughtProcess({
|
||||
}) {
|
||||
return (
|
||||
<div
|
||||
role="button"
|
||||
aria-label="Toggle thought process display"
|
||||
tabIndex={0}
|
||||
className={classNames({
|
||||
'collapse bg-none': true,
|
||||
@@ -292,7 +304,11 @@ function ThoughtProcess({
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
<div className="collapse-content text-base-content/70 text-sm p-1">
|
||||
<div
|
||||
className="collapse-content text-base-content/70 text-sm p-1"
|
||||
tabIndex={0}
|
||||
aria-description="Thought process content"
|
||||
>
|
||||
<div className="border-l-2 border-base-content/20 pl-4 mb-4">
|
||||
<MarkdownDisplay content={content} />
|
||||
</div>
|
||||
|
||||
@@ -279,7 +279,11 @@ export default function ChatScreen() {
|
||||
function ServerInfo() {
|
||||
const { serverProps } = useAppContext();
|
||||
return (
|
||||
<div className="card card-sm shadow-sm border-1 border-base-content/20 text-base-content/70 mb-6">
|
||||
<div
|
||||
className="card card-sm shadow-sm border-1 border-base-content/20 text-base-content/70 mb-6"
|
||||
tabIndex={0}
|
||||
aria-description="Server information"
|
||||
>
|
||||
<div className="card-body">
|
||||
<b>Server Info</b>
|
||||
<p>
|
||||
@@ -311,6 +315,8 @@ function ChatInput({
|
||||
|
||||
return (
|
||||
<div
|
||||
role="group"
|
||||
aria-label="Chat input"
|
||||
className={classNames({
|
||||
'flex items-end pt-8 pb-6 sticky bottom-0 bg-base-100': true,
|
||||
'opacity-50': isDrag, // simply visual feedback to inform user that the file will be accepted
|
||||
@@ -400,13 +406,15 @@ function ChatInput({
|
||||
'btn w-8 h-8 p-0 rounded-full': true,
|
||||
'btn-disabled': isGenerating,
|
||||
})}
|
||||
aria-label="Upload file"
|
||||
tabIndex={0}
|
||||
role="button"
|
||||
>
|
||||
<PaperClipIcon className="h-5 w-5" />
|
||||
</label>
|
||||
<input
|
||||
id="file-upload"
|
||||
type="file"
|
||||
className="hidden"
|
||||
disabled={isGenerating}
|
||||
{...getInputProps()}
|
||||
hidden
|
||||
@@ -422,6 +430,7 @@ function ChatInput({
|
||||
<button
|
||||
className="btn btn-primary w-8 h-8 p-0 rounded-full"
|
||||
onClick={onSend}
|
||||
aria-label="Send message"
|
||||
>
|
||||
<ArrowUpIcon className="h-5 w-5" />
|
||||
</button>
|
||||
|
||||
@@ -38,8 +38,12 @@ export default function Header() {
|
||||
|
||||
{/* action buttons (top right) */}
|
||||
<div className="flex items-center">
|
||||
<div className="tooltip tooltip-bottom" data-tip="Settings">
|
||||
<button className="btn" onClick={() => setShowSettings(true)}>
|
||||
<div
|
||||
className="tooltip tooltip-bottom"
|
||||
data-tip="Settings"
|
||||
onClick={() => setShowSettings(true)}
|
||||
>
|
||||
<button className="btn" aria-hidden={true}>
|
||||
{/* settings button */}
|
||||
<Cog8ToothIcon className="w-5 h-5" />
|
||||
</button>
|
||||
|
||||
@@ -335,14 +335,22 @@ export default function SettingDialog({
|
||||
};
|
||||
|
||||
return (
|
||||
<dialog className={classNames({ modal: true, 'modal-open': show })}>
|
||||
<dialog
|
||||
className={classNames({ modal: true, 'modal-open': show })}
|
||||
aria-label="Settings dialog"
|
||||
>
|
||||
<div className="modal-box w-11/12 max-w-3xl">
|
||||
<h3 className="text-lg font-bold mb-6">Settings</h3>
|
||||
<div className="flex flex-col md:flex-row h-[calc(90vh-12rem)]">
|
||||
{/* Left panel, showing sections - Desktop version */}
|
||||
<div className="hidden md:flex flex-col items-stretch pr-4 mr-4 border-r-2 border-base-200">
|
||||
<div
|
||||
className="hidden md:flex flex-col items-stretch pr-4 mr-4 border-r-2 border-base-200"
|
||||
role="complementary"
|
||||
aria-description="Settings sections"
|
||||
tabIndex={0}
|
||||
>
|
||||
{SETTING_SECTIONS.map((section, idx) => (
|
||||
<div
|
||||
<button
|
||||
key={idx}
|
||||
className={classNames({
|
||||
'btn btn-ghost justify-start font-normal w-44 mb-1': true,
|
||||
@@ -352,12 +360,16 @@ export default function SettingDialog({
|
||||
dir="auto"
|
||||
>
|
||||
{section.title}
|
||||
</div>
|
||||
</button>
|
||||
))}
|
||||
</div>
|
||||
|
||||
{/* Left panel, showing sections - Mobile version */}
|
||||
<div className="md:hidden flex flex-row gap-2 mb-4">
|
||||
{/* This menu is skipped on a11y, otherwise it's repeated the desktop version */}
|
||||
<div
|
||||
className="md:hidden flex flex-row gap-2 mb-4"
|
||||
aria-disabled={true}
|
||||
>
|
||||
<details className="dropdown">
|
||||
<summary className="btn bt-sm w-full m-1">
|
||||
{SETTING_SECTIONS[sectionIdx].title}
|
||||
|
||||
@@ -50,44 +50,72 @@ export default function Sidebar() {
|
||||
id="toggle-drawer"
|
||||
type="checkbox"
|
||||
className="drawer-toggle"
|
||||
aria-label="Toggle sidebar"
|
||||
defaultChecked
|
||||
/>
|
||||
|
||||
<div className="drawer-side h-screen lg:h-screen z-50 lg:max-w-64">
|
||||
<div
|
||||
className="drawer-side h-screen lg:h-screen z-50 lg:max-w-64"
|
||||
role="complementary"
|
||||
aria-label="Sidebar"
|
||||
tabIndex={0}
|
||||
>
|
||||
<label
|
||||
htmlFor="toggle-drawer"
|
||||
aria-label="close sidebar"
|
||||
aria-label="Close sidebar"
|
||||
className="drawer-overlay"
|
||||
></label>
|
||||
|
||||
<a
|
||||
href="#main-scroll"
|
||||
className="absolute -left-80 top-0 w-1 h-1 overflow-hidden"
|
||||
>
|
||||
Skip to main content
|
||||
</a>
|
||||
|
||||
<div className="flex flex-col bg-base-200 min-h-full max-w-64 py-4 px-4">
|
||||
<div className="flex flex-row items-center justify-between mb-4 mt-4">
|
||||
<h2 className="font-bold ml-4">Conversations</h2>
|
||||
<h2 className="font-bold ml-4" role="heading">
|
||||
Conversations
|
||||
</h2>
|
||||
|
||||
{/* close sidebar button */}
|
||||
<label htmlFor="toggle-drawer" className="btn btn-ghost lg:hidden">
|
||||
<label
|
||||
htmlFor="toggle-drawer"
|
||||
className="btn btn-ghost lg:hidden"
|
||||
aria-label="Close sidebar"
|
||||
role="button"
|
||||
tabIndex={0}
|
||||
>
|
||||
<XMarkIcon className="w-5 h-5" />
|
||||
</label>
|
||||
</div>
|
||||
|
||||
{/* new conversation button */}
|
||||
<div
|
||||
<button
|
||||
className={classNames({
|
||||
'btn btn-ghost justify-start px-2': true,
|
||||
'btn-soft': !currConv,
|
||||
})}
|
||||
onClick={() => navigate('/')}
|
||||
aria-label="New conversation"
|
||||
>
|
||||
<PencilSquareIcon className="w-5 h-5" />
|
||||
New conversation
|
||||
</div>
|
||||
</button>
|
||||
|
||||
{/* list of conversations */}
|
||||
{groupedConv.map((group, i) => (
|
||||
<div key={i}>
|
||||
<div key={i} role="group">
|
||||
{/* group name (by date) */}
|
||||
{group.title ? (
|
||||
// we use btn class here to make sure that the padding/margin are aligned with the other items
|
||||
<b className="btn btn-ghost btn-xs bg-none btn-disabled block text-xs text-base-content text-start px-2 mb-0 mt-6 font-bold">
|
||||
<b
|
||||
className="btn btn-ghost btn-xs bg-none btn-disabled block text-xs text-base-content text-start px-2 mb-0 mt-6 font-bold"
|
||||
role="note"
|
||||
aria-description={group.title}
|
||||
tabIndex={0}
|
||||
>
|
||||
{group.title}
|
||||
</b>
|
||||
) : (
|
||||
@@ -184,20 +212,23 @@ function ConversationItem({
|
||||
}) {
|
||||
return (
|
||||
<div
|
||||
role="menuitem"
|
||||
tabIndex={0}
|
||||
aria-label={conv.name}
|
||||
className={classNames({
|
||||
'group flex flex-row btn btn-ghost justify-start items-center font-normal px-2 h-9':
|
||||
true,
|
||||
'btn-soft': isCurrConv,
|
||||
})}
|
||||
>
|
||||
<div
|
||||
<button
|
||||
key={conv.id}
|
||||
className="w-full overflow-hidden truncate text-start"
|
||||
onClick={onSelect}
|
||||
dir="auto"
|
||||
>
|
||||
{conv.name}
|
||||
</div>
|
||||
</button>
|
||||
<div className="dropdown dropdown-end h-5">
|
||||
<BtnWithTooltips
|
||||
// on mobile, we always show the ellipsis icon
|
||||
@@ -211,22 +242,23 @@ function ConversationItem({
|
||||
</BtnWithTooltips>
|
||||
{/* dropdown menu */}
|
||||
<ul
|
||||
aria-label="More options"
|
||||
tabIndex={0}
|
||||
className="dropdown-content menu bg-base-100 rounded-box z-[1] p-2 shadow"
|
||||
>
|
||||
<li onClick={onRename}>
|
||||
<li onClick={onRename} tabIndex={0}>
|
||||
<a>
|
||||
<PencilIcon className="w-4 h-4" />
|
||||
Rename
|
||||
</a>
|
||||
</li>
|
||||
<li onClick={onDownload}>
|
||||
<li onClick={onDownload} tabIndex={0}>
|
||||
<a>
|
||||
<ArrowDownTrayIcon className="w-4 h-4" />
|
||||
Download
|
||||
</a>
|
||||
</li>
|
||||
<li className="text-error" onClick={onDelete}>
|
||||
<li className="text-error" onClick={onDelete} tabIndex={0}>
|
||||
<a>
|
||||
<TrashIcon className="w-4 h-4" />
|
||||
Delete
|
||||
|
||||
@@ -34,9 +34,6 @@ html {
|
||||
/* TODO: fix markdown table */
|
||||
}
|
||||
|
||||
.show-on-hover {
|
||||
@apply md:opacity-0 md:group-hover:opacity-100;
|
||||
}
|
||||
.btn-mini {
|
||||
@apply cursor-pointer;
|
||||
}
|
||||
|
||||
@@ -52,13 +52,20 @@ export function BtnWithTooltips({
|
||||
tooltipsContent: string;
|
||||
disabled?: boolean;
|
||||
}) {
|
||||
// the onClick handler is on the container, so screen readers can safely ignore the inner button
|
||||
// this prevents the label from being read twice
|
||||
return (
|
||||
<div className="tooltip tooltip-bottom" data-tip={tooltipsContent}>
|
||||
<div
|
||||
className="tooltip tooltip-bottom"
|
||||
data-tip={tooltipsContent}
|
||||
role="button"
|
||||
onClick={onClick}
|
||||
>
|
||||
<button
|
||||
className={`${className ?? ''} flex items-center justify-center`}
|
||||
onClick={onClick}
|
||||
disabled={disabled}
|
||||
onMouseLeave={onMouseLeave}
|
||||
aria-hidden={true}
|
||||
>
|
||||
{children}
|
||||
</button>
|
||||
|
||||
Reference in New Issue
Block a user