mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2026-03-17 16:44:07 +00:00
kleidiai: add data type check to get_tensor_traits (#20639)
* kleidiai: add data type check to get_tensor_traits

  * Added a check for the F16 data type to the get_tensor_traits path that handles input data not in ggml_backend_cpu_kleidiai_buffer_type format (unsupported for Q4/8).

  Signed-off-by: Martin Klacer <martin.klacer@arm.com>
  Change-Id: I9aca4b9b8d669d35db6f1dbcc4e080b1919b1de7

* updated ggml/src/ggml-cpu/kleidiai/kleidiai.cpp

  Updated the kleidiai.cpp file as per suggestion.

  Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

---------

Signed-off-by: Martin Klacer <martin.klacer@arm.com>
Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
This commit is contained in:
@@ -1473,10 +1473,12 @@ class extra_buffer_type : ggml::cpu::extra_buffer_type {
|
||||
if (op->src[0]->buffer && op->src[0]->buffer->buft == ggml_backend_cpu_kleidiai_buffer_type()) {
|
||||
return (ggml::cpu::tensor_traits *) op->src[0]->extra;
|
||||
} else {
|
||||
if (op->src[0]->type != GGML_TYPE_F16) {
|
||||
return nullptr;
|
||||
}
|
||||
std::array<ggml_kleidiai_kernels *, GGML_KLEIDIAI_MAX_KERNEL_SLOTS> kernel_chain;
|
||||
const int slot_total = kleidiai_collect_kernel_chain(op, kernel_chain);
|
||||
const bool has_kernel = slot_total > 0;
|
||||
if (has_kernel && op->src[1]->ne[1] > 1) {
|
||||
if (slot_total > 0 && op->src[1]->ne[1] > 1) {
|
||||
if ((op->src[0]->nb[1] * op->src[0]->ne[1] != op->src[0]->nb[2]) ||
|
||||
(op->src[1]->nb[1] * op->src[1]->ne[1] != op->src[1]->nb[2])) {
|
||||
return nullptr;
|
||||
|
||||
Reference in New Issue
Block a user