Compare commits

...

9 Commits

Author SHA1 Message Date
Adrien Gallouët
69ffd89163 ggml-amx : fix ggml_amx_init() on generic Linux (#16049)
Generalize Linux check to `__linux__` to support non-glibc systems (like musl).
Also, return `false` on unknown/untested OS.
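
A condensed sketch of the resulting guard, assuming the shape of `ggml_amx_init()` shown in the diff further down:

    // Condensed from the ggml_amx_init() diff below: request AMX tile
    // permission on Linux, assume availability on Windows, and fail closed
    // everywhere else instead of crashing on an illegal instruction later.
    static bool ggml_amx_init() {
    #if defined(__linux__)      // was __gnu_linux__, which musl does not define
        if (syscall(SYS_arch_prctl, ARCH_REQ_XCOMP_PERM, XFEATURE_XTILEDATA)) {
            fprintf(stderr, "AMX is not ready to be used!\n");
            return false;
        }
        return true;
    #elif defined(_WIN32)
        return true;
    #else
        return false;           // unknown/untested OS
    #endif
    }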

Without this commit, the code compiles (with warnings) but fails:

    register_backend: registered backend CPU (1 devices)
    register_device: registered device CPU (Intel(R) Xeon(R) Platinum 8488C)
    build: 6487 (51c4cac6) with x86_64-linux-musl-gcc (GCC) 15.1.0 for x86_64-linux-musl (debug)
    system info: n_threads = 8, n_threads_batch = 8, total_threads = 16
    ....
    print_info: n_ctx_orig_yarn  = 262144
    print_info: rope_finetuned   = unknown
    print_info: model type       = 4B
    Illegal instruction (core dumped)

Signed-off-by: Adrien Gallouët <angt@huggingface.co>
2025-09-18 23:07:26 +02:00
Adrien Gallouët
246c0d9c79 cmake : fix static linking for OpenMP on Unix-like systems (#16031)
When compiling with GGML_STATIC=ON, the build process would produce a
binary that was still dynamically linked to OpenMP. This defeats the
purpose of a static build:

    $ cmake -B build \
            -DBUILD_SHARED_LIBS=OFF \
            -DLLAMA_CURL=OFF \
            -DGGML_CCACHE=OFF \
            -DGGML_NATIVE=OFF \
            -DGGML_STATIC=ON

    $ ldd llama-server
            linux-vdso.so.1 (0x0000e1a434e3b000)
            libgomp.so.1 => /lib/aarch64-linux-gnu/libgomp.so.1 (0x0000e1a4345a0000)
            libstdc++.so.6 => /lib/aarch64-linux-gnu/libstdc++.so.6 (0x0000e1a434300000)
            libm.so.6 => /lib/aarch64-linux-gnu/libm.so.6 (0x0000e1a434240000)
            libgcc_s.so.1 => /lib/aarch64-linux-gnu/libgcc_s.so.1 (0x0000e1a434200000)
            libc.so.6 => /lib/aarch64-linux-gnu/libc.so.6 (0x0000e1a434030000)
            /lib/ld-linux-aarch64.so.1 (0x0000e1a434df0000)

This commit resolves the issue by modifying `CMAKE_FIND_LIBRARY_SUFFIXES`
to prioritize `.a` files, forcing CMake to link the static version of
the library.

Signed-off-by: Adrien Gallouët <angt@huggingface.co>
2025-09-18 23:07:18 +02:00
Shawn Gu
3edd87cd05 opencl: optimize mxfp4 kernels (#16037)
- flatten mxfp4 and packed fp4->fp16 bit-wise convert function (replace lut)
- MoE kernel optimizations
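
For reference, a plausible scalar form of such a LUT-free convert, assuming the standard e2m1 element layout of mxfp4 (1 sign, 2 exponent, 1 mantissa bit); the real kernels work on packed pairs, and the helper name is hypothetical:

    // e2m1 nibble -> IEEE fp16 bits, no lookup table (illustrative sketch)
    static inline uint16_t fp4_e2m1_to_fp16_bits(uint8_t q) {
        const uint16_t s = (uint16_t)(q & 0x8) << 12; // sign -> fp16 bit 15
        const uint16_t e = (q >> 1) & 0x3;            // 2-bit exponent, bias 1
        const uint16_t m = q & 0x1;                   // 1-bit mantissa
        if (e == 0) {
            return s | (m ? 0x3800 : 0x0000);         // subnormal: 0.0 or 0.5
        }
        // normal: rebias to fp16 (bias 15), mantissa moves to bit 9
        return s | (uint16_t)((e - 1 + 15) << 10) | (uint16_t)(m << 9);
    }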

---------

Co-authored-by: Li He <lih@qti.qualcomm.com>
2025-09-18 12:03:34 -07:00
Jeff Bolz
c0b45097c3 rename optimize_graph to graph_optimize (#16082) 2025-09-18 13:46:17 -05:00
Bowen Han
38dbdf4c05 CUDA: Optimize PAD_REFLECT_1D (#15957)
* CUDA: Optimize PAD_REFLECT_1D
feat: add more test cases for PAD_REFLECT_1D

* use fast_div to improve performance

* Apply suggestion from JohannesGaessler

Co-authored-by: Johannes Gäßler <johannesg@5d6.de>

* Apply suggestion from JohannesGaessler

Co-authored-by: Johannes Gäßler <johannesg@5d6.de>

* optimize

* use a concise expression to further speedup the cuda kernel
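
The "concise expression" is the reflect-index form visible in the kernel diff below; a host-side reference of the same mapping (hypothetical helper name):

    // Reflect-padding index math for output position i0, left pad p0,
    // source row length ne00 (matches the CUDA kernel in the diff below).
    static inline int64_t reflect_src_index(int64_t i0, int64_t p0, int64_t ne00) {
        const int64_t rel_i0 = i0 - p0;    // position relative to src0
        if (rel_i0 < 0) {
            return -rel_i0;                // left padding - reflect
        }
        if (rel_i0 < ne00) {
            return rel_i0;                 // middle - copy
        }
        return 2 * ne00 - 2 - rel_i0;      // right padding - reflect
    }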

---------

Co-authored-by: Johannes Gäßler <johannesg@5d6.de>
2025-09-18 20:26:03 +02:00
Johannes Gäßler
368560a1e3 CUDA: fix compilation on CC 6.0 (#16091) 2025-09-18 19:28:32 +02:00
Eric Curtin
4ca088b036 Add resumable downloads for llama-server model loading (#15963)
- Implement resumable downloads in common_download_file_single function
- Add detection of partial download files (.downloadInProgress)
- Check server support for HTTP Range requests via Accept-Ranges header
- Implement HTTP Range request with "bytes=<start>-" header
- Open files in append mode when resuming vs create mode for new downloads
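
A minimal sketch of the resume path, condensed from common_pull_file() in the diff below (hypothetical wrapper name, error handling trimmed):

    #include <curl/curl.h>
    #include <cstdio>
    #include <filesystem>
    #include <string>

    // If a partial file exists, request only the remaining bytes and append.
    // libcurl expands CURLOPT_RANGE "<start>-" into "Range: bytes=<start>-".
    static bool pull_with_resume(CURL * curl, const std::string & path_tmp) {
        if (std::filesystem::exists(path_tmp)) {
            const std::string range = std::to_string(std::filesystem::file_size(path_tmp)) + "-";
            curl_easy_setopt(curl, CURLOPT_RANGE, range.c_str());
        }
        FILE * f = std::fopen(path_tmp.c_str(), "ab"); // append keeps bytes already on disk
        if (!f) {
            return false;
        }
        curl_easy_setopt(curl, CURLOPT_WRITEDATA, f);  // default callback fwrite()s to f
        const CURLcode res = curl_easy_perform(curl);
        std::fclose(f);
        return res == CURLE_OK;
    }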

Signed-off-by: Eric Curtin <eric.curtin@docker.com>
2025-09-18 16:22:50 +01:00
Georgi Gerganov
703f9e32c4 metal : use function constants for mul_mv_ext kernels (#16074)
* metal : use function constants for mul_mv_ext kernels

ggml-ci

* metal : remove NW template argument

ggml-ci

* metal : adjust constants

ggml-ci
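
The mechanism, condensed from ggml_metal_library_get_pipeline_mul_mv_ext() in the diff below: bake the launch parameters into the compiled pipeline as function constants instead of passing them as kernel arguments (the comments on what nsg/nxpsg denote are a reading of the kernel code):

    ggml_metal_cv_t cv = ggml_metal_cv_init();
    ggml_metal_cv_set_int16(cv, nsg,   FC_MUL_MV + 0); // simdgroups per threadgroup
    ggml_metal_cv_set_int16(cv, nxpsg, FC_MUL_MV + 1); // threads along x per simdgroup
    res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
    ggml_metal_cv_free(cv);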
2025-09-18 16:28:41 +03:00
Sigbjørn Skjæret
ad6bd9083b cuda : add missing F32<->I32 entries in ggml_cuda_cpy_fn (#16060) 2025-09-18 13:28:22 +02:00
30 changed files with 1239 additions and 455 deletions

View File

@@ -57,12 +57,32 @@ static std::string read_file(const std::string & fname) {
}
static void write_file(const std::string & fname, const std::string & content) {
std::ofstream file(fname);
const std::string fname_tmp = fname + ".tmp";
std::ofstream file(fname_tmp);
if (!file) {
throw std::runtime_error(string_format("error: failed to open file '%s'\n", fname.c_str()));
}
file << content;
file.close();
try {
file << content;
file.close();
// Makes write atomic
if (rename(fname_tmp.c_str(), fname.c_str()) != 0) {
LOG_ERR("%s: unable to rename file: %s to %s\n", __func__, fname_tmp.c_str(), fname.c_str());
// If rename fails, try to delete the temporary file
if (remove(fname_tmp.c_str()) != 0) {
LOG_ERR("%s: unable to delete temporary file: %s\n", __func__, fname_tmp.c_str());
}
}
} catch (...) {
// If anything fails, try to delete the temporary file
if (remove(fname_tmp.c_str()) != 0) {
LOG_ERR("%s: unable to delete temporary file: %s\n", __func__, fname_tmp.c_str());
}
throw std::runtime_error(string_format("error: failed to write file '%s'\n", fname.c_str()));
}
}
common_arg & common_arg::set_examples(std::initializer_list<enum llama_example> examples) {
@@ -217,250 +237,294 @@ struct curl_slist_ptr {
}
};
#define CURL_MAX_RETRY 3
#define CURL_RETRY_DELAY_SECONDS 2
static bool curl_perform_with_retry(const std::string & url, CURL * curl, int max_attempts, int retry_delay_seconds, const char * method_name) {
int remaining_attempts = max_attempts;
while (remaining_attempts > 0) {
LOG_INF("%s: %s %s (attempt %d of %d)...\n", __func__ , method_name, url.c_str(), max_attempts - remaining_attempts + 1, max_attempts);
CURLcode res = curl_easy_perform(curl);
if (res == CURLE_OK) {
return true;
}
int exponential_backoff_delay = std::pow(retry_delay_seconds, max_attempts - remaining_attempts) * 1000;
LOG_WRN("%s: curl_easy_perform() failed: %s, retrying after %d milliseconds...\n", __func__, curl_easy_strerror(res), exponential_backoff_delay);
remaining_attempts--;
if (remaining_attempts == 0) break;
std::this_thread::sleep_for(std::chrono::milliseconds(exponential_backoff_delay));
static CURLcode common_curl_perf(CURL * curl) {
CURLcode res = curl_easy_perform(curl);
if (res != CURLE_OK) {
LOG_ERR("%s: curl_easy_perform() failed\n", __func__);
}
LOG_ERR("%s: curl_easy_perform() failed after %d attempts\n", __func__, max_attempts);
return false;
return res;
}
// download one single file from remote URL to local path
static bool common_download_file_single(const std::string & url, const std::string & path, const std::string & bearer_token, bool offline) {
// Check if the file already exists locally
auto file_exists = std::filesystem::exists(path);
// If the file exists, check its JSON metadata companion file.
std::string metadata_path = path + ".json";
nlohmann::json metadata; // TODO @ngxson : get rid of this json, use regex instead
// Send a HEAD request to retrieve the etag and last-modified headers
struct common_load_model_from_url_headers {
std::string etag;
std::string last_modified;
std::string accept_ranges;
};
if (file_exists) {
if (offline) {
LOG_INF("%s: using cached file (offline mode): %s\n", __func__, path.c_str());
return true; // skip verification/downloading
struct FILE_deleter {
void operator()(FILE * f) const { fclose(f); }
};
static size_t common_header_callback(char * buffer, size_t, size_t n_items, void * userdata) {
common_load_model_from_url_headers * headers = (common_load_model_from_url_headers *) userdata;
static std::regex header_regex("([^:]+): (.*)\r\n");
static std::regex etag_regex("ETag", std::regex_constants::icase);
static std::regex last_modified_regex("Last-Modified", std::regex_constants::icase);
static std::regex accept_ranges_regex("Accept-Ranges", std::regex_constants::icase);
std::string header(buffer, n_items);
std::smatch match;
if (std::regex_match(header, match, header_regex)) {
const std::string & key = match[1];
const std::string & value = match[2];
if (std::regex_match(key, match, etag_regex)) {
headers->etag = value;
} else if (std::regex_match(key, match, last_modified_regex)) {
headers->last_modified = value;
} else if (std::regex_match(key, match, accept_ranges_regex)) {
headers->accept_ranges = value;
}
// Try and read the JSON metadata file (note: stream autoclosed upon exiting this block).
std::ifstream metadata_in(metadata_path);
if (metadata_in.good()) {
try {
metadata_in >> metadata;
LOG_DBG("%s: previous metadata file found %s: %s\n", __func__, metadata_path.c_str(), metadata.dump().c_str());
if (metadata.contains("etag") && metadata.at("etag").is_string()) {
etag = metadata.at("etag");
}
if (metadata.contains("lastModified") && metadata.at("lastModified").is_string()) {
last_modified = metadata.at("lastModified");
}
} catch (const nlohmann::json::exception & e) {
LOG_ERR("%s: error reading metadata file %s: %s\n", __func__, metadata_path.c_str(), e.what());
}
}
// if we cannot open the metadata file, we assume that the downloaded file is not valid (etag and last-modified are left empty, so we will download it again)
} else {
if (offline) {
LOG_ERR("%s: required file is not available in cache (offline mode): %s\n", __func__, path.c_str());
return false;
}
LOG_INF("%s: no previous model file found %s\n", __func__, path.c_str());
}
// Send a HEAD request to retrieve the etag and last-modified headers
struct common_load_model_from_url_headers {
std::string etag;
std::string last_modified;
};
return n_items;
}
common_load_model_from_url_headers headers;
bool head_request_ok = false;
bool should_download = !file_exists; // by default, we should download if the file does not exist
static size_t common_write_callback(void * data, size_t size, size_t nmemb, void * fd) {
return std::fwrite(data, size, nmemb, static_cast<FILE *>(fd));
}
// Initialize libcurl
curl_ptr curl(curl_easy_init(), &curl_easy_cleanup);
curl_slist_ptr http_headers;
// helper function to hide password in URL
static std::string llama_download_hide_password_in_url(const std::string & url) {
// Use regex to match and replace the user[:password]@ pattern in URLs
// Pattern: scheme://[user[:password]@]host[...]
static const std::regex url_regex(R"(^([A-Za-z][A-Za-z0-9+.-]*://)([^/@]+@)(.*)$)");
std::smatch match;
if (std::regex_match(url, match, url_regex)) {
// match[1] = scheme (e.g., "https://")
// match[2] = user[:password]@ part
// match[3] = rest of URL (host and path)
return match[1].str() + "********@" + match[3].str();
}
return url; // No credentials found or malformed URL
}
static void common_curl_easy_setopt_head(CURL * curl, const std::string & url) {
// Set the URL, allow to follow http redirection
curl_easy_setopt(curl, CURLOPT_URL, url.c_str());
curl_easy_setopt(curl, CURLOPT_FOLLOWLOCATION, 1L);
# if defined(_WIN32)
// CURLSSLOPT_NATIVE_CA tells libcurl to use standard certificate store of
// operating system. Currently implemented under MS-Windows.
curl_easy_setopt(curl, CURLOPT_SSL_OPTIONS, CURLSSLOPT_NATIVE_CA);
# endif
curl_easy_setopt(curl, CURLOPT_NOBODY, 1L); // will trigger the HEAD verb
curl_easy_setopt(curl, CURLOPT_NOPROGRESS, 1L); // hide head request progress
curl_easy_setopt(curl, CURLOPT_HEADERFUNCTION, common_header_callback);
}
static void common_curl_easy_setopt_get(CURL * curl) {
curl_easy_setopt(curl, CURLOPT_NOBODY, 0L);
curl_easy_setopt(curl, CURLOPT_WRITEFUNCTION, common_write_callback);
// display download progress
curl_easy_setopt(curl, CURLOPT_NOPROGRESS, 0L);
}
static bool common_pull_file(CURL * curl, const std::string & path_temporary) {
if (std::filesystem::exists(path_temporary)) {
const std::string partial_size = std::to_string(std::filesystem::file_size(path_temporary));
LOG_INF("%s: server supports range requests, resuming download from byte %s\n", __func__, partial_size.c_str());
const std::string range_str = partial_size + "-";
curl_easy_setopt(curl, CURLOPT_RANGE, range_str.c_str());
}
// Always open the file in append mode, as we could be resuming
std::unique_ptr<FILE, FILE_deleter> outfile(fopen(path_temporary.c_str(), "ab"));
if (!outfile) {
LOG_ERR("%s: error opening local file for writing: %s\n", __func__, path_temporary.c_str());
return false;
}
common_curl_easy_setopt_get(curl);
curl_easy_setopt(curl, CURLOPT_WRITEDATA, outfile.get());
return common_curl_perf(curl) == CURLE_OK;
}
static bool common_download_head(CURL * curl,
curl_slist_ptr & http_headers,
const std::string & url,
const std::string & bearer_token) {
if (!curl) {
LOG_ERR("%s: error initializing libcurl\n", __func__);
return false;
}
// Set the URL, allow to follow http redirection
curl_easy_setopt(curl.get(), CURLOPT_URL, url.c_str());
curl_easy_setopt(curl.get(), CURLOPT_FOLLOWLOCATION, 1L);
http_headers.ptr = curl_slist_append(http_headers.ptr, "User-Agent: llama-cpp");
// Check if hf-token or bearer-token was specified
if (!bearer_token.empty()) {
std::string auth_header = "Authorization: Bearer " + bearer_token;
http_headers.ptr = curl_slist_append(http_headers.ptr, auth_header.c_str());
}
curl_easy_setopt(curl.get(), CURLOPT_HTTPHEADER, http_headers.ptr);
#if defined(_WIN32)
// CURLSSLOPT_NATIVE_CA tells libcurl to use standard certificate store of
// operating system. Currently implemented under MS-Windows.
curl_easy_setopt(curl.get(), CURLOPT_SSL_OPTIONS, CURLSSLOPT_NATIVE_CA);
#endif
typedef size_t(*CURLOPT_HEADERFUNCTION_PTR)(char *, size_t, size_t, void *);
auto header_callback = [](char * buffer, size_t /*size*/, size_t n_items, void * userdata) -> size_t {
common_load_model_from_url_headers * headers = (common_load_model_from_url_headers *) userdata;
static std::regex header_regex("([^:]+): (.*)\r\n");
static std::regex etag_regex("ETag", std::regex_constants::icase);
static std::regex last_modified_regex("Last-Modified", std::regex_constants::icase);
std::string header(buffer, n_items);
std::smatch match;
if (std::regex_match(header, match, header_regex)) {
const std::string & key = match[1];
const std::string & value = match[2];
if (std::regex_match(key, match, etag_regex)) {
headers->etag = value;
} else if (std::regex_match(key, match, last_modified_regex)) {
headers->last_modified = value;
}
}
return n_items;
};
curl_easy_setopt(curl.get(), CURLOPT_NOBODY, 1L); // will trigger the HEAD verb
curl_easy_setopt(curl.get(), CURLOPT_NOPROGRESS, 1L); // hide head request progress
curl_easy_setopt(curl.get(), CURLOPT_HEADERFUNCTION, static_cast<CURLOPT_HEADERFUNCTION_PTR>(header_callback));
curl_easy_setopt(curl.get(), CURLOPT_HEADERDATA, &headers);
// we only allow retrying once for HEAD requests
// this is for the use case of running offline (no internet), where retrying can be annoying
bool was_perform_successful = curl_perform_with_retry(url, curl.get(), 1, 0, "HEAD");
if (!was_perform_successful) {
head_request_ok = false;
http_headers.ptr = curl_slist_append(http_headers.ptr, auth_header.c_str());
}
long http_code = 0;
curl_easy_getinfo(curl.get(), CURLINFO_RESPONSE_CODE, &http_code);
if (http_code == 200) {
head_request_ok = true;
} else {
LOG_WRN("%s: HEAD invalid http status code received: %ld\n", __func__, http_code);
head_request_ok = false;
}
curl_easy_setopt(curl, CURLOPT_HTTPHEADER, http_headers.ptr);
common_curl_easy_setopt_head(curl, url);
return common_curl_perf(curl) == CURLE_OK;
}
// if head_request_ok is false, we don't have the etag or last-modified headers
// we leave should_download as-is, which is true if the file does not exist
if (head_request_ok) {
// check if ETag or Last-Modified headers are different
// if it is, we need to download the file again
if (!etag.empty() && etag != headers.etag) {
LOG_WRN("%s: ETag header is different (%s != %s): triggering a new download\n", __func__, etag.c_str(), headers.etag.c_str());
should_download = true;
} else if (!last_modified.empty() && last_modified != headers.last_modified) {
LOG_WRN("%s: Last-Modified header is different (%s != %s): triggering a new download\n", __func__, last_modified.c_str(), headers.last_modified.c_str());
should_download = true;
}
}
// download one single file from remote URL to local path
static bool common_download_file_single(const std::string & url,
const std::string & path,
const std::string & bearer_token,
bool offline) {
// If the file exists, check its JSON metadata companion file.
std::string metadata_path = path + ".json";
static const int max_attempts = 3;
static const int retry_delay_seconds = 2;
for (int i = 0; i < max_attempts; ++i) {
nlohmann::json metadata; // TODO @ngxson : get rid of this json, use regex instead
std::string etag;
std::string last_modified;
if (should_download) {
std::string path_temporary = path + ".downloadInProgress";
// Check if the file already exists locally
const auto file_exists = std::filesystem::exists(path);
if (file_exists) {
LOG_WRN("%s: deleting previous downloaded file: %s\n", __func__, path.c_str());
if (remove(path.c_str()) != 0) {
LOG_ERR("%s: unable to delete file: %s\n", __func__, path.c_str());
if (offline) {
LOG_INF("%s: using cached file (offline mode): %s\n", __func__, path.c_str());
return true; // skip verification/downloading
}
// Try and read the JSON metadata file (note: stream autoclosed upon exiting this block).
std::ifstream metadata_in(metadata_path);
if (metadata_in.good()) {
try {
metadata_in >> metadata;
LOG_DBG("%s: previous metadata file found %s: %s\n", __func__, metadata_path.c_str(),
metadata.dump().c_str());
if (metadata.contains("etag") && metadata.at("etag").is_string()) {
etag = metadata.at("etag");
}
if (metadata.contains("lastModified") && metadata.at("lastModified").is_string()) {
last_modified = metadata.at("lastModified");
}
} catch (const nlohmann::json::exception & e) {
LOG_ERR("%s: error reading metadata file %s: %s\n", __func__, metadata_path.c_str(), e.what());
}
}
// if we cannot open the metadata file, we assume that the downloaded file is not valid (etag and last-modified are left empty, so we will download it again)
} else {
if (offline) {
LOG_ERR("%s: required file is not available in cache (offline mode): %s\n", __func__, path.c_str());
return false;
}
LOG_INF("%s: no previous model file found %s\n", __func__, path.c_str());
}
// Set the output file
bool head_request_ok = false;
bool should_download = !file_exists; // by default, we should download if the file does not exist
struct FILE_deleter {
void operator()(FILE * f) const {
fclose(f);
}
};
std::unique_ptr<FILE, FILE_deleter> outfile(fopen(path_temporary.c_str(), "wb"));
if (!outfile) {
LOG_ERR("%s: error opening local file for writing: %s\n", __func__, path.c_str());
return false;
}
typedef size_t(*CURLOPT_WRITEFUNCTION_PTR)(void * data, size_t size, size_t nmemb, void * fd);
auto write_callback = [](void * data, size_t size, size_t nmemb, void * fd) -> size_t {
return fwrite(data, size, nmemb, (FILE *)fd);
};
curl_easy_setopt(curl.get(), CURLOPT_NOBODY, 0L);
curl_easy_setopt(curl.get(), CURLOPT_WRITEFUNCTION, static_cast<CURLOPT_WRITEFUNCTION_PTR>(write_callback));
curl_easy_setopt(curl.get(), CURLOPT_WRITEDATA, outfile.get());
// display download progress
curl_easy_setopt(curl.get(), CURLOPT_NOPROGRESS, 0L);
// helper function to hide password in URL
auto llama_download_hide_password_in_url = [](const std::string & url) -> std::string {
std::size_t protocol_pos = url.find("://");
if (protocol_pos == std::string::npos) {
return url; // Malformed URL
}
std::size_t at_pos = url.find('@', protocol_pos + 3);
if (at_pos == std::string::npos) {
return url; // No password in URL
}
return url.substr(0, protocol_pos + 3) + "********" + url.substr(at_pos);
};
// start the download
LOG_INF("%s: trying to download model from %s to %s (server_etag:%s, server_last_modified:%s)...\n", __func__,
llama_download_hide_password_in_url(url).c_str(), path.c_str(), headers.etag.c_str(), headers.last_modified.c_str());
bool was_perform_successful = curl_perform_with_retry(url, curl.get(), CURL_MAX_RETRY, CURL_RETRY_DELAY_SECONDS, "GET");
// Initialize libcurl
curl_ptr curl(curl_easy_init(), &curl_easy_cleanup);
common_load_model_from_url_headers headers;
curl_easy_setopt(curl.get(), CURLOPT_HEADERDATA, &headers);
curl_slist_ptr http_headers;
const bool was_perform_successful = common_download_head(curl.get(), http_headers, url, bearer_token);
if (!was_perform_successful) {
return false;
head_request_ok = false;
}
long http_code = 0;
curl_easy_getinfo (curl.get(), CURLINFO_RESPONSE_CODE, &http_code);
if (http_code < 200 || http_code >= 400) {
LOG_ERR("%s: invalid http status code received: %ld\n", __func__, http_code);
return false;
curl_easy_getinfo(curl.get(), CURLINFO_RESPONSE_CODE, &http_code);
if (http_code == 200) {
head_request_ok = true;
} else {
LOG_WRN("%s: HEAD invalid http status code received: %ld\n", __func__, http_code);
head_request_ok = false;
}
// Causes file to be closed explicitly here before we rename it.
outfile.reset();
// Write the updated JSON metadata file.
metadata.update({
{"url", url},
{"etag", headers.etag},
{"lastModified", headers.last_modified}
});
write_file(metadata_path, metadata.dump(4));
LOG_DBG("%s: file metadata saved: %s\n", __func__, metadata_path.c_str());
if (rename(path_temporary.c_str(), path.c_str()) != 0) {
LOG_ERR("%s: unable to rename file: %s to %s\n", __func__, path_temporary.c_str(), path.c_str());
return false;
// if head_request_ok is false, we don't have the etag or last-modified headers
// we leave should_download as-is, which is true if the file does not exist
bool should_download_from_scratch = false;
if (head_request_ok) {
// check if ETag or Last-Modified headers are different
// if it is, we need to download the file again
if (!etag.empty() && etag != headers.etag) {
LOG_WRN("%s: ETag header is different (%s != %s): triggering a new download\n", __func__, etag.c_str(),
headers.etag.c_str());
should_download = true;
should_download_from_scratch = true;
} else if (!last_modified.empty() && last_modified != headers.last_modified) {
LOG_WRN("%s: Last-Modified header is different (%s != %s): triggering a new download\n", __func__,
last_modified.c_str(), headers.last_modified.c_str());
should_download = true;
should_download_from_scratch = true;
}
}
} else {
LOG_INF("%s: using cached file: %s\n", __func__, path.c_str());
const bool accept_ranges_supported = !headers.accept_ranges.empty() && headers.accept_ranges != "none";
if (should_download) {
if (file_exists &&
!accept_ranges_supported) { // Resumable downloads not supported, delete and start again.
LOG_WRN("%s: deleting previous downloaded file: %s\n", __func__, path.c_str());
if (remove(path.c_str()) != 0) {
LOG_ERR("%s: unable to delete file: %s\n", __func__, path.c_str());
return false;
}
}
const std::string path_temporary = path + ".downloadInProgress";
if (should_download_from_scratch) {
if (std::filesystem::exists(path_temporary)) {
if (remove(path_temporary.c_str()) != 0) {
LOG_ERR("%s: unable to delete file: %s\n", __func__, path_temporary.c_str());
return false;
}
}
if (std::filesystem::exists(path)) {
if (remove(path.c_str()) != 0) {
LOG_ERR("%s: unable to delete file: %s\n", __func__, path.c_str());
return false;
}
}
}
// Write the updated JSON metadata file.
metadata.update({
{ "url", url },
{ "etag", headers.etag },
{ "lastModified", headers.last_modified }
});
write_file(metadata_path, metadata.dump(4));
LOG_DBG("%s: file metadata saved: %s\n", __func__, metadata_path.c_str());
// start the download
LOG_INF("%s: trying to download model from %s to %s (server_etag:%s, server_last_modified:%s)...\n",
__func__, llama_download_hide_password_in_url(url).c_str(), path_temporary.c_str(),
headers.etag.c_str(), headers.last_modified.c_str());
const bool was_pull_successful = common_pull_file(curl.get(), path_temporary);
if (!was_pull_successful) {
if (i + 1 < max_attempts) {
const int exponential_backoff_delay = std::pow(retry_delay_seconds, i) * 1000;
LOG_WRN("%s: retrying after %d milliseconds...\n", __func__, exponential_backoff_delay);
std::this_thread::sleep_for(std::chrono::milliseconds(exponential_backoff_delay));
} else {
LOG_ERR("%s: curl_easy_perform() failed after %d attempts\n", __func__, max_attempts);
}
continue;
}
long http_code = 0;
curl_easy_getinfo(curl.get(), CURLINFO_RESPONSE_CODE, &http_code);
if (http_code < 200 || http_code >= 400) {
LOG_ERR("%s: invalid http status code received: %ld\n", __func__, http_code);
return false;
}
if (rename(path_temporary.c_str(), path.c_str()) != 0) {
LOG_ERR("%s: unable to rename file: %s to %s\n", __func__, path_temporary.c_str(), path.c_str());
return false;
}
} else {
LOG_INF("%s: using cached file: %s\n", __func__, path.c_str());
}
break;
}
return true;
@@ -770,7 +834,7 @@ static std::string common_docker_get_token(const std::string & repo) {
}
static std::string common_docker_resolve_model(const std::string & docker) {
// Parse ai/smollm2:135M-Q4_K_M
// Parse ai/smollm2:135M-Q4_0
size_t colon_pos = docker.find(':');
std::string repo, tag;
if (colon_pos != std::string::npos) {

View File

@@ -114,6 +114,9 @@ message(STATUS "GGML_SYSTEM_ARCH: ${GGML_SYSTEM_ARCH}")
if (NOT MSVC)
if (GGML_STATIC)
if (UNIX AND NOT APPLE)
set(CMAKE_FIND_LIBRARY_SUFFIXES ".a;.so")
endif()
add_link_options(-static)
if (MINGW)
add_link_options(-static-libgcc -static-libstdc++)

View File

@@ -116,7 +116,7 @@ extern "C" {
void (*event_wait) (ggml_backend_t backend, ggml_backend_event_t event);
// (optional) sort/optimize the nodes in the graph
void (*optimize_graph) (ggml_backend_t backend, struct ggml_cgraph * cgraph);
void (*graph_optimize) (ggml_backend_t backend, struct ggml_cgraph * cgraph);
};
struct ggml_backend {

View File

@@ -463,10 +463,10 @@ void ggml_backend_event_wait(ggml_backend_t backend, ggml_backend_event_t event)
backend->iface.event_wait(backend, event);
}
static void ggml_backend_optimize_graph(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
static void ggml_backend_graph_optimize(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
GGML_ASSERT(backend);
if (backend->iface.optimize_graph != NULL) {
backend->iface.optimize_graph(backend, cgraph);
if (backend->iface.graph_optimize != NULL) {
backend->iface.graph_optimize(backend, cgraph);
}
}
@@ -1307,7 +1307,7 @@ void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct ggml_cgra
// Optimize this split of the graph. This needs to happen before we make graph_copy,
// so they are in sync.
ggml_backend_optimize_graph(sched->backends[split->backend_id], &split->graph);
ggml_backend_graph_optimize(sched->backends[split->backend_id], &split->graph);
// add inputs to the graph copy so that they are allocated by ggml-alloc at the start of the split
for (int j = 0; j < split->n_inputs; j++) {

View File

@@ -270,7 +270,7 @@ static struct ggml_backend_i blas_backend_i = {
/* .graph_compute = */ ggml_backend_blas_graph_compute,
/* .event_record = */ NULL,
/* .event_wait = */ NULL,
/* .optimize_graph = */ NULL,
/* .graph_optimize = */ NULL,
};
static ggml_guid_t ggml_backend_blas_guid(void) {

View File

@@ -2756,7 +2756,7 @@ static const ggml_backend_i ggml_backend_cann_interface = {
/* .graph_compute = */ ggml_backend_cann_graph_compute,
/* .event_record = */ ggml_backend_cann_event_record,
/* .event_wait = */ ggml_backend_cann_event_wait,
/* .optimize_graph = */ NULL,
/* .graph_optimize = */ NULL,
};
/**

View File

@@ -7,7 +7,7 @@
#include "ggml-cpu.h"
#include "traits.h"
#if defined(__gnu_linux__)
#if defined(__linux__)
#include <sys/syscall.h>
#include <unistd.h>
#endif
@@ -186,7 +186,7 @@ static size_t ggml_backend_amx_buffer_type_get_alloc_size(ggml_backend_buffer_ty
#define XFEATURE_XTILEDATA 18
static bool ggml_amx_init() {
#if defined(__gnu_linux__)
#if defined(__linux__)
if (syscall(SYS_arch_prctl, ARCH_REQ_XCOMP_PERM, XFEATURE_XTILEDATA)) {
fprintf(stderr, "AMX is not ready to be used!\n");
return false;
@@ -194,6 +194,8 @@ static bool ggml_amx_init() {
return true;
#elif defined(_WIN32)
return true;
#else
return false;
#endif
}

View File

@@ -190,7 +190,7 @@ static const struct ggml_backend_i ggml_backend_cpu_i = {
/* .graph_compute = */ ggml_backend_cpu_graph_compute,
/* .event_record = */ NULL,
/* .event_wait = */ NULL,
/* .optimize_graph = */ NULL,
/* .graph_optimize = */ NULL,
};
static ggml_guid_t ggml_backend_cpu_guid(void) {

View File

@@ -652,6 +652,14 @@ static __device__ __forceinline__ uint32_t fastmodulo(uint32_t n, const uint3 fa
return n - fastdiv(n, fastdiv_values) * fastdiv_values.z;
}
// Calculate both division and modulo at once, returns <n/divisor, n%divisor>
static __device__ __forceinline__ uint2 fast_div_modulo(uint32_t n, const uint3 fastdiv_values) {
// expects fastdiv_values to contain <mp, L, divisor> in <x, y, z> (see init_fastdiv_values)
const uint32_t div_val = fastdiv(n, fastdiv_values);
const uint32_t mod_val = n - div_val * fastdiv_values.z;
return make_uint2(div_val, mod_val);
}
typedef void (*dequantize_kernel_t)(const void * vx, const int64_t ib, const int iqs, float2 & v);
static __device__ __forceinline__ float get_alibi_slope(

View File

@@ -441,6 +441,10 @@ void* ggml_cuda_cpy_fn(const ggml_tensor * src0, ggml_tensor * src1) {
return (void*) cpy_flt<cpy_1_flt<nv_bfloat16, nv_bfloat16>>;
} else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F32) {
return (void*) cpy_flt<cpy_1_flt<nv_bfloat16, float>>;
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_I32) {
return (void*) cpy_flt<cpy_1_flt<float, int32_t>>;
} else if (src0->type == GGML_TYPE_I32 && src1->type == GGML_TYPE_F32) {
return (void*) cpy_flt<cpy_1_flt<int32_t, float>>;
} else {
GGML_ABORT("%s: unsupported type combination (%s to %s)\n", __func__,
ggml_type_name(src0->type), ggml_type_name(src1->type));

View File

@@ -35,7 +35,6 @@ static int fattn_tile_get_kq_stride_host(const int D, const int ncols, const int
switch (D) {
case 64:
case 128:
return 128;
case 256:
return ncols <= 16 ? 128 : 64;
default:
@@ -86,7 +85,6 @@ static constexpr __device__ int fattn_tile_get_kq_stride_device(int D, int ncols
switch (D) {
case 64:
case 128:
return 128;
case 256:
return ncols <= 16 ? 128 : 64;
default:

View File

@@ -3140,7 +3140,7 @@ static const ggml_backend_i ggml_backend_cuda_interface = {
/* .graph_compute = */ ggml_backend_cuda_graph_compute,
/* .event_record = */ ggml_backend_cuda_event_record,
/* .event_wait = */ ggml_backend_cuda_event_wait,
/* .optimize_graph = */ NULL,
/* .graph_optimize = */ NULL,
};
static ggml_guid_t ggml_backend_cuda_guid() {

View File

@@ -1,82 +1,89 @@
#include "pad_reflect_1d.cuh"
static __global__ void pad_reflect_1d_kernel_f32(
const void * __restrict__ src0,
void * __restrict__ dst,
const int64_t ne0,
const int64_t ne00,
const int64_t ne01,
const int64_t ne02,
const int64_t ne03,
const int64_t nb00,
const int64_t nb01,
const int64_t nb02,
const int64_t nb03,
const int64_t nb0,
const int64_t nb1,
const int64_t nb2,
const int64_t nb3,
const int p0,
const int p1) {
static __global__ __launch_bounds__(CUDA_PAD_REFLECT_1D_BLOCK_SIZE, 1) void
pad_reflect_1d_kernel_f32(
const void * __restrict__ src0,
void * __restrict__ dst,
const int64_t ne0,
const int64_t ne00,
const uint3 ne01,
const int64_t ne02,
const int64_t ne03,
const int64_t nb00,
const int64_t nb01,
const int64_t nb02,
const int64_t nb03,
const int64_t nb0,
const int64_t nb1,
const int64_t nb2,
const int64_t nb3,
const int p0,
const int p1) {
const int64_t i3 = blockIdx.z;
const int64_t i2 = blockIdx.y;
const int64_t i1 = blockIdx.x;
if (i1 >= ne01 || i2 >= ne02 || i3 >= ne03) {
const uint2 div_mod_packed = fast_div_modulo(blockIdx.x, ne01);
const int64_t tile1 = div_mod_packed.y; // i1
const int64_t tile0 = div_mod_packed.x; // nth i0 tile
const int64_t i1 = tile1;
const int64_t i0 = threadIdx.x + tile0 * blockDim.x;
// ne01.z is original value of unpacked ne01 (see init_fastdiv_values in common.cuh)
if (i0 >= ne0 || i1 >= ne01.z || i2 >= ne02 || i3 >= ne03) {
return;
}
const char * src0_ptr = (const char *)src0 + i3*nb03 + i2*nb02 + i1*nb01;
char * dst_ptr = (char *)dst + i3*nb3 + i2*nb2 + i1*nb1;
const char * src0_ptr = (const char *) src0 + i3 * nb03 + i2 * nb02 + i1 * nb01;
char * dst_ptr = (char *) dst + i3 * nb3 + i2 * nb2 + i1 * nb1;
for (int64_t i0 = threadIdx.x; i0 < ne0; i0 += blockDim.x) {
float value;
const int64_t rel_i0 = i0 - p0; // relative i0 in src0
int64_t src_idx;
if (i0 < p0) {
// Left padding - reflect
value = *(const float *)(src0_ptr + (p0 - i0) * nb00);
} else if (i0 < ne0 - p1) {
// Middle - copy
value = *(const float *)(src0_ptr + (i0 - p0) * nb00);
} else {
// Right padding - reflect
int64_t src_idx = (ne0 - p1 - p0) - (p1 + 1 - (ne0 - i0)) - 1;
value = *(const float *)(src0_ptr + src_idx * nb00);
}
*(float *)(dst_ptr + i0 * nb0) = value;
if (rel_i0 < 0) {
// Left padding - reflect
src_idx = -rel_i0;
} else if (rel_i0 < ne00) {
// Middle - copy
src_idx = rel_i0;
} else {
// Right padding - reflect
src_idx = 2 * ne00 - 2 - rel_i0;
}
const float value = *(const float *) (src0_ptr + src_idx * nb00);
*(float *) (dst_ptr + i0 * nb0) = value;
}
void ggml_cuda_op_pad_reflect_1d(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
cudaStream_t stream = ctx.stream();
const ggml_tensor * src0 = dst->src[0];
cudaStream_t stream = ctx.stream();
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT(dst->type == GGML_TYPE_F32);
const int32_t * opts = (const int32_t *) dst->op_params;
const int p0 = opts[0];
const int p1 = opts[1];
const int p0 = opts[0];
const int p1 = opts[1];
const int64_t ne00 = src0->ne[0];
const int64_t ne01 = src0->ne[1];
const int64_t ne02 = src0->ne[2];
const int64_t ne03 = src0->ne[3];
const int64_t ne00 = src0->ne[0];
const int64_t ne01 = src0->ne[1];
const uint3 ne01_packed = init_fastdiv_values(ne01);
const int64_t ne02 = src0->ne[2];
const int64_t ne03 = src0->ne[3];
const int64_t ne0 = dst->ne[0];
// sanity: padded length matches
GGML_ASSERT(ne0 == ne00 + p0 + p1);
const dim3 block_dims(CUDA_PAD_REFLECT_1D_BLOCK_SIZE, 1, 1);
const dim3 grid_dims(ne01, ne02, ne03);
constexpr int64_t bx = CUDA_PAD_REFLECT_1D_BLOCK_SIZE; // threads per block (x)
const int64_t tiles0 = (ne0 + bx - 1) / bx; // number of tiles along i0
// grid.x covers i1 and all tiles of i0: [ne01 * tiles0]
// grid.y covers i2: [ne02]
// grid.z covers i3: [ne03]
const dim3 grid_dims((unsigned) (ne01 * tiles0), (unsigned) ne02, (unsigned) ne03);
const dim3 block_dims((unsigned) bx, 1, 1);
pad_reflect_1d_kernel_f32<<<grid_dims, block_dims, 0, stream>>>(
src0->data, dst->data,
ne0, ne00, ne01, ne02, ne03,
src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3],
dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3],
p0, p1
);
src0->data, dst->data, ne0, ne00, ne01_packed, ne02, ne03, src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3],
dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3], p0, p1);
}

View File

@@ -414,19 +414,26 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_rwkv(ggml_metal_library_t
return res;
}
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv_ext(ggml_metal_library_t lib, ggml_type tsrc0, ggml_type tsrc1, int r1ptg) {
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv_ext(ggml_metal_library_t lib, ggml_type tsrc0, ggml_type tsrc1, int nsg, int nxpsg, int r1ptg) {
char base[256];
char name[256];
snprintf(base, 256, "kernel_mul_mv_ext_%s_%s_r1_%d", ggml_type_name(tsrc0), ggml_type_name(tsrc1), r1ptg);
snprintf(name, 256, "%s", base);
snprintf(name, 256, "%s_nsg=%d_nxpsg=%d", base, nsg, nxpsg);
ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
if (res) {
return res;
}
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
ggml_metal_cv_t cv = ggml_metal_cv_init();
ggml_metal_cv_set_int16(cv, nsg, FC_MUL_MV + 0);
ggml_metal_cv_set_int16(cv, nxpsg, FC_MUL_MV + 1);
res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
ggml_metal_cv_free(cv);
return res;
}
@@ -608,7 +615,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv(ggml_metal_library_
};
snprintf(base, 256, "kernel_mul_mv_%s_%s%s", ggml_type_name(tsrc0), ggml_type_name(tsrc1), suffix);
snprintf(name, 256, "%s", base);
snprintf(name, 256, "%s_nsg=%d", base, nsg);
ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
if (res) {
@@ -824,7 +831,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv_id(ggml_metal_libra
};
snprintf(base, 256, "kernel_mul_mv_id_%s_%s%s", ggml_type_name(tsrc0), ggml_type_name(tsrc1), suffix);
snprintf(name, 256, "%s", base);
snprintf(name, 256, "%s_nsg=%d", base, nsg);
ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
if (res) {
@@ -923,11 +930,8 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext(
dk,
dv);
snprintf(name, 256, "kernel_%s_%s_dk%d_dv%d_mask=%d_sinks=%d_bias=%d_scap=%d_ns10=%d_ns20=%d_nsg=%d",
"flash_attn_ext",
ggml_type_name(op->src[1]->type),
dk,
dv,
snprintf(name, 256, "%s_mask=%d_sinks=%d_bias=%d_scap=%d_ns10=%d_ns20=%d_nsg=%d",
base,
has_mask,
has_sinks,
has_bias,
@@ -985,11 +989,8 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_vec(
dk,
dv);
snprintf(name, 256, "kernel_%s_%s_dk%d_dv%d_mask=%d_sink=%d_bias=%d_softcap=%d_ns10=%d_ns20=%d_nsg=%d_nwg=%d",
"flash_attn_ext_vec",
ggml_type_name(op->src[1]->type),
dk,
dv,
snprintf(name, 256, "%s_mask=%d_sink=%d_bias=%d_softcap=%d_ns10=%d_ns20=%d_nsg=%d_nwg=%d",
base,
has_mask,
has_sinks,
has_bias,
@@ -1033,7 +1034,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_vec_reduce(
char name[256];
snprintf(base, 256, "kernel_flash_attn_ext_vec_reduce");
snprintf(name, 256, "kernel_flash_attn_ext_vec_reduce_dv=%d_nwg=%d", dv, nwg);
snprintf(name, 256, "%s_dv=%d_nwg=%d", base, dv, nwg);
ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
if (res) {

View File

@@ -114,7 +114,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_soft_max (ggml_me
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_ssm_conv (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_ssm_scan (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_rwkv (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv_ext (ggml_metal_library_t lib, enum ggml_type tsrc0, enum ggml_type tsrc1, int r1ptg);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv_ext (ggml_metal_library_t lib, enum ggml_type tsrc0, enum ggml_type tsrc1, int nsg, int nxpsg, int r1ptg);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mm (ggml_metal_library_t lib, enum ggml_type tsrc0, enum ggml_type tsrc1);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mm_id_map0 (ggml_metal_library_t lib, int ne02, int ne20);

View File

@@ -35,13 +35,13 @@
#define N_R0_Q3_K 2
#define N_SG_Q3_K 2
#define N_R0_Q4_K 4
#define N_R0_Q4_K 2
#define N_SG_Q4_K 2
#define N_R0_Q5_K 2
#define N_SG_Q5_K 2
#define N_R0_Q6_K 1
#define N_R0_Q6_K 2
#define N_SG_Q6_K 2
#define N_R0_IQ1_S 4
@@ -374,9 +374,6 @@ typedef struct {
int32_t ne1;
int16_t r2;
int16_t r3;
int16_t nsg;
int16_t nxpsg;
int16_t r1ptg;
} ggml_metal_kargs_mul_mv_ext;
typedef struct {

View File

@@ -1444,7 +1444,7 @@ int ggml_metal_op_mul_mat(ggml_metal_op_t ctx, int idx) {
GGML_ABORT("unsupported ne11");
};
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_mul_mv_ext(lib, op->src[0]->type, op->src[1]->type, r1ptg);
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_mul_mv_ext(lib, op->src[0]->type, op->src[1]->type, nsg, nxpsg, r1ptg);
ggml_metal_kargs_mul_mv_ext args = {
/*.ne00 =*/ ne00,
@@ -1465,9 +1465,6 @@ int ggml_metal_op_mul_mat(ggml_metal_op_t ctx, int idx) {
/*.ne1 =*/ ne1,
/*.r2 =*/ r2,
/*.r3 =*/ r3,
/*.nsg =*/ nsg,
/*.nxpsg =*/ nxpsg,
/*.r1ptg =*/ r1ptg,
};
ggml_metal_encoder_set_pipeline(enc, pipeline);

View File

@@ -447,7 +447,7 @@ static ggml_backend_i ggml_backend_metal_i = {
// https://developer.apple.com/documentation/metal/mtlcommandbuffer#Synchronizing-Passes-with-Events
/* .event_record = */ NULL,
/* .event_wait = */ NULL,
/* .optimize_graph = */ ggml_backend_metal_graph_optimize,
/* .graph_optimize = */ ggml_backend_metal_graph_optimize,
};
static ggml_guid_t ggml_backend_metal_guid(void) {

View File

@@ -2843,7 +2843,7 @@ inline float block_q_n_dot_y(device const block_q5_1 * qb_curr, float sumy, thre
return d * (acc[0] + acc[1] + acc[2] + acc[3]) + sumy * m;
}
template<short NR0, short NW>
template<short NR0>
static inline void helper_mv_reduce_and_write(
device float * dst_f32,
float sumf[NR0],
@@ -2852,6 +2852,8 @@ static inline void helper_mv_reduce_and_write(
ushort tiisg,
ushort sgitg,
threadgroup char * shmem) {
constexpr short NW = N_SIMDWIDTH;
threadgroup float * shmem_f32[NR0];
for (short row = 0; row < NR0; ++row) {
@@ -2883,9 +2885,10 @@ static inline void helper_mv_reduce_and_write(
}
}
constant short FC_mul_mv_nsg [[function_constant(FC_MUL_MV + 0)]];
constant short FC_mul_mv_nsg [[function_constant(FC_MUL_MV + 0)]];
constant short FC_mul_mv_nxpsg [[function_constant(FC_MUL_MV + 1)]];
template<typename block_q_type, short NR0, short NW, typename args_t>
template<typename block_q_type, short NR0, typename args_t>
void mul_vec_q_n_f32_impl(
args_t args,
device const char * src0,
@@ -2897,6 +2900,7 @@ void mul_vec_q_n_f32_impl(
ushort sgitg) {
const short NSG = FC_mul_mv_nsg;
constexpr short NW = N_SIMDWIDTH;
constexpr short NQ = 16;
const int nb = args.ne00/QK4_0;
@@ -2961,7 +2965,7 @@ void mul_vec_q_n_f32_impl(
device float * dst_f32 = (device float *) dst + im*args.ne0*args.ne1 + r1*args.ne0;
//helper_mv_reduce_and_write<NR0, NW>(dst_f32, sumf, r0, args.ne01, tiisg, sgitg, shmem);
//helper_mv_reduce_and_write<NR0>(dst_f32, sumf, r0, args.ne01, tiisg, sgitg, shmem);
for (int row = 0; row < NR0; ++row) {
const float tot = simd_sum(sumf[row]);
@@ -2981,7 +2985,7 @@ kernel void kernel_mul_mv_q4_0_f32(
uint3 tgpig[[threadgroup_position_in_grid]],
ushort tiisg[[thread_index_in_simdgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
mul_vec_q_n_f32_impl<block_q4_0, N_R0_Q4_0, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
mul_vec_q_n_f32_impl<block_q4_0, N_R0_Q4_0, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
}
kernel void kernel_mul_mv_q4_1_f32(
@@ -2993,7 +2997,7 @@ kernel void kernel_mul_mv_q4_1_f32(
uint3 tgpig[[threadgroup_position_in_grid]],
ushort tiisg[[thread_index_in_simdgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
mul_vec_q_n_f32_impl<block_q4_1, N_R0_Q4_1, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
mul_vec_q_n_f32_impl<block_q4_1, N_R0_Q4_1, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
}
kernel void kernel_mul_mv_q5_0_f32(
@@ -3005,7 +3009,7 @@ kernel void kernel_mul_mv_q5_0_f32(
uint3 tgpig[[threadgroup_position_in_grid]],
ushort tiisg[[thread_index_in_simdgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
mul_vec_q_n_f32_impl<block_q5_0, N_R0_Q5_0, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
mul_vec_q_n_f32_impl<block_q5_0, N_R0_Q5_0, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
}
kernel void kernel_mul_mv_q5_1_f32(
@@ -3017,10 +3021,10 @@ kernel void kernel_mul_mv_q5_1_f32(
uint3 tgpig[[threadgroup_position_in_grid]],
ushort tiisg[[thread_index_in_simdgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
mul_vec_q_n_f32_impl<block_q5_1, N_R0_Q5_1, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
mul_vec_q_n_f32_impl<block_q5_1, N_R0_Q5_1, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
}
template<short NR0, short NW, typename args_t>
template<short NR0, typename args_t>
void kernel_mul_mv_q8_0_f32_impl(
args_t args,
device const char * src0,
@@ -3032,6 +3036,7 @@ void kernel_mul_mv_q8_0_f32_impl(
ushort sgitg) {
const short NSG = FC_mul_mv_nsg;
constexpr short NW = N_SIMDWIDTH;
constexpr short NQ = 8;
const int nb = args.ne00/QK8_0;
@@ -3090,7 +3095,7 @@ void kernel_mul_mv_q8_0_f32_impl(
device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
helper_mv_reduce_and_write<NR0, NW>(dst_f32, sumf, r0, args.ne01, tiisg, sgitg, shmem);
helper_mv_reduce_and_write<NR0>(dst_f32, sumf, r0, args.ne01, tiisg, sgitg, shmem);
}
[[host_name("kernel_mul_mv_q8_0_f32")]]
@@ -3103,12 +3108,12 @@ kernel void kernel_mul_mv_q8_0_f32(
uint3 tgpig[[threadgroup_position_in_grid]],
ushort tiisg[[thread_index_in_simdgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
kernel_mul_mv_q8_0_f32_impl<N_R0_Q8_0, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
kernel_mul_mv_q8_0_f32_impl<N_R0_Q8_0, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
}
// mat-vec kernel processing in chunks of float4
// chpb - chunks per quantization block
template<short nxpsg, short r1ptg, typename q_t, short chpb, void (*deq_t4)(device const q_t *, short, thread float4 &) >
template<short r1ptg, typename q_t, short chpb, void (*deq_t4)(device const q_t *, short, thread float4 &) >
void kernel_mul_mv_ext_q4_f32_impl(
constant ggml_metal_kargs_mul_mv_ext & args,
device const char * src0,
@@ -3117,6 +3122,9 @@ void kernel_mul_mv_ext_q4_f32_impl(
uint3 tgpig[[threadgroup_position_in_grid]],
ushort tiisg[[thread_index_in_simdgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
const short NSG = FC_mul_mv_nsg;
const short nxpsg = FC_mul_mv_nxpsg;
const short chpt = 4; // chunks per thread
//const short nxpsg = (32);
@@ -3125,7 +3133,7 @@ void kernel_mul_mv_ext_q4_f32_impl(
const short tx = tiisg%nxpsg;
const short ty = tiisg/nxpsg;
const int i01 = tgpig.x*(nypsg*args.nsg) + nypsg*sgitg + ty;
const int i01 = tgpig.x*(nypsg*NSG) + nypsg*sgitg + ty;
const int i11 = tgpig.y*r1ptg;
const int i1m = tgpig.z;
@@ -3208,7 +3216,7 @@ void kernel_mul_mv_ext_q4_f32_impl(
}
// mat-vec kernel processing in chunks of float4x4
template<short nxpsg, short r1ptg, typename q_t, short chpb, void (*deq_t4x4)(device const q_t *, short, thread float4x4 &) >
template<short r1ptg, typename q_t, short chpb, void (*deq_t4x4)(device const q_t *, short, thread float4x4 &) >
void kernel_mul_mv_ext_q4x4_f32_impl(
constant ggml_metal_kargs_mul_mv_ext & args,
device const char * src0,
@@ -3217,6 +3225,9 @@ void kernel_mul_mv_ext_q4x4_f32_impl(
uint3 tgpig[[threadgroup_position_in_grid]],
ushort tiisg[[thread_index_in_simdgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
const short NSG = FC_mul_mv_nsg;
const short nxpsg = FC_mul_mv_nxpsg;
const short chpt = 1;
//const short nxpsg = (32);
@@ -3225,7 +3236,7 @@ void kernel_mul_mv_ext_q4x4_f32_impl(
const short tx = tiisg%nxpsg;
const short ty = tiisg/nxpsg;
const int i01 = tgpig.x*(nypsg*args.nsg) + nypsg*sgitg + ty;
const int i01 = tgpig.x*(nypsg*NSG) + nypsg*sgitg + ty;
const int i11 = tgpig.y*r1ptg;
const int i1m = tgpig.z;
@@ -3322,12 +3333,7 @@ kernel void kernel_mul_mv_ext_q4_f32_disp(
uint3 tgpig[[threadgroup_position_in_grid]],
ushort tiisg[[thread_index_in_simdgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
switch (args.nxpsg) {
case 4: kernel_mul_mv_ext_q4_f32_impl<4, r1ptg, q_t, epb/4, deq_t4>(args, src0, src1, dst, tgpig, tiisg, sgitg); break;
case 8: kernel_mul_mv_ext_q4_f32_impl<8, r1ptg, q_t, epb/4, deq_t4>(args, src0, src1, dst, tgpig, tiisg, sgitg); break;
case 16: kernel_mul_mv_ext_q4_f32_impl<16, r1ptg, q_t, epb/4, deq_t4>(args, src0, src1, dst, tgpig, tiisg, sgitg); break;
case 32: kernel_mul_mv_ext_q4_f32_impl<32, r1ptg, q_t, epb/4, deq_t4>(args, src0, src1, dst, tgpig, tiisg, sgitg); break;
}
kernel_mul_mv_ext_q4_f32_impl<r1ptg, q_t, epb/4, deq_t4>(args, src0, src1, dst, tgpig, tiisg, sgitg);
}
template<short r1ptg, typename q_t, short epb, void (*deq_t4x4)(device const q_t *, short, thread float4x4 &)>
@@ -3339,12 +3345,7 @@ kernel void kernel_mul_mv_ext_q4x4_f32_disp(
uint3 tgpig[[threadgroup_position_in_grid]],
ushort tiisg[[thread_index_in_simdgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
switch (args.nxpsg) {
case 4: kernel_mul_mv_ext_q4x4_f32_impl<4, r1ptg, q_t, epb/16, deq_t4x4>(args, src0, src1, dst, tgpig, tiisg, sgitg); break;
case 8: kernel_mul_mv_ext_q4x4_f32_impl<8, r1ptg, q_t, epb/16, deq_t4x4>(args, src0, src1, dst, tgpig, tiisg, sgitg); break;
case 16: kernel_mul_mv_ext_q4x4_f32_impl<16, r1ptg, q_t, epb/16, deq_t4x4>(args, src0, src1, dst, tgpig, tiisg, sgitg); break;
case 32: kernel_mul_mv_ext_q4x4_f32_impl<32, r1ptg, q_t, epb/16, deq_t4x4>(args, src0, src1, dst, tgpig, tiisg, sgitg); break;
}
kernel_mul_mv_ext_q4x4_f32_impl<r1ptg, q_t, epb/16, deq_t4x4>(args, src0, src1, dst, tgpig, tiisg, sgitg);
}
typedef decltype(kernel_mul_mv_ext_q4_f32_disp <2, block_q8_0, 32, dequantize_q8_0_t4>) mul_mv_ext_q4_f32_t;
@@ -3410,7 +3411,7 @@ template [[host_name("kernel_mul_mv_ext_q6_K_f32_r1_3")]] kernel mul_mv_ext_q4x4
template [[host_name("kernel_mul_mv_ext_q6_K_f32_r1_4")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<4, block_q6_K, 256, dequantize_q6_K>;
template [[host_name("kernel_mul_mv_ext_q6_K_f32_r1_5")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<5, block_q6_K, 256, dequantize_q6_K>;
template<typename T0, typename T1, short NR0, short NW, typename args_t>
template<typename T0, typename T1, short NR0, typename args_t>
void kernel_mul_mv_t_t_impl(
args_t args,
device const char * src0,
@@ -3422,6 +3423,7 @@ void kernel_mul_mv_t_t_impl(
ushort sgitg) {
const short NSG = FC_mul_mv_nsg;
constexpr short NW = N_SIMDWIDTH;
constexpr short NB = 32;
constexpr short NF = 8;
@@ -3486,10 +3488,10 @@ void kernel_mul_mv_t_t_impl(
device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
helper_mv_reduce_and_write<NR0, NW>(dst_f32, sumf, r0, args.ne01, tiisg, sgitg, shmem);
helper_mv_reduce_and_write<NR0>(dst_f32, sumf, r0, args.ne01, tiisg, sgitg, shmem);
}
template<typename T0, typename T1, short NR0, short NW>
template<typename T0, typename T1, short NR0>
kernel void kernel_mul_mv_t_t(
constant ggml_metal_kargs_mul_mv & args,
device const char * src0,
@@ -3499,20 +3501,20 @@ kernel void kernel_mul_mv_t_t(
uint3 tgpig[[threadgroup_position_in_grid]],
ushort tiisg[[thread_index_in_simdgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
kernel_mul_mv_t_t_impl<T0, T1, NR0, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
kernel_mul_mv_t_t_impl<T0, T1, NR0, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
}
typedef decltype(kernel_mul_mv_t_t<half, half, N_R0_F, N_SIMDWIDTH>) mul_mv_t_t;
typedef decltype(kernel_mul_mv_t_t<half, half, N_R0_F>) mul_mv_t_t;
template [[host_name("kernel_mul_mv_f32_f32")]] kernel mul_mv_t_t kernel_mul_mv_t_t<float, float, N_R0_F, N_SIMDWIDTH>;
template [[host_name("kernel_mul_mv_f16_f32")]] kernel mul_mv_t_t kernel_mul_mv_t_t<half, float, N_R0_F, N_SIMDWIDTH>;
template [[host_name("kernel_mul_mv_f16_f16")]] kernel mul_mv_t_t kernel_mul_mv_t_t<half, half, N_R0_F, N_SIMDWIDTH>;
template [[host_name("kernel_mul_mv_f32_f32")]] kernel mul_mv_t_t kernel_mul_mv_t_t<float, float, N_R0_F>;
template [[host_name("kernel_mul_mv_f16_f32")]] kernel mul_mv_t_t kernel_mul_mv_t_t<half, float, N_R0_F>;
template [[host_name("kernel_mul_mv_f16_f16")]] kernel mul_mv_t_t kernel_mul_mv_t_t<half, half, N_R0_F>;
#if defined(GGML_METAL_HAS_BF16)
template [[host_name("kernel_mul_mv_bf16_f32")]] kernel mul_mv_t_t kernel_mul_mv_t_t<bfloat, float, N_R0_F, N_SIMDWIDTH>;
template [[host_name("kernel_mul_mv_bf16_bf16")]] kernel mul_mv_t_t kernel_mul_mv_t_t<bfloat, bfloat, N_R0_F, N_SIMDWIDTH>;
template [[host_name("kernel_mul_mv_bf16_f32")]] kernel mul_mv_t_t kernel_mul_mv_t_t<bfloat, float, N_R0_F>;
template [[host_name("kernel_mul_mv_bf16_bf16")]] kernel mul_mv_t_t kernel_mul_mv_t_t<bfloat, bfloat, N_R0_F>;
#endif
template<typename T0, typename T04, typename T1, typename T14, short NR0, short NW, typename args_t>
template<typename T0, typename T04, typename T1, typename T14, short NR0, typename args_t>
void kernel_mul_mv_t_t_4_impl(
args_t args,
device const char * src0,
@@ -3524,6 +3526,7 @@ void kernel_mul_mv_t_t_4_impl(
ushort sgitg) {
const short NSG = FC_mul_mv_nsg;
constexpr short NW = N_SIMDWIDTH;
constexpr short NB = 32;
constexpr short NF = 16;
constexpr short NF4 = NF/4;
@@ -3591,10 +3594,10 @@ void kernel_mul_mv_t_t_4_impl(
device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
helper_mv_reduce_and_write<NR0, NW>(dst_f32, sumf, r0, args.ne01, tiisg, sgitg, shmem);
helper_mv_reduce_and_write<NR0>(dst_f32, sumf, r0, args.ne01, tiisg, sgitg, shmem);
}
template<typename T0, typename T04, typename T1, typename T14, short NR0, short NW>
template<typename T0, typename T04, typename T1, typename T14, short NR0>
kernel void kernel_mul_mv_t_t_4(
constant ggml_metal_kargs_mul_mv & args,
device const char * src0,
@@ -3604,17 +3607,17 @@ kernel void kernel_mul_mv_t_t_4(
uint3 tgpig[[threadgroup_position_in_grid]],
ushort tiisg[[thread_index_in_simdgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
kernel_mul_mv_t_t_4_impl<T0, T04, T1, T14, NR0, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
kernel_mul_mv_t_t_4_impl<T0, T04, T1, T14, NR0, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
}
typedef decltype(kernel_mul_mv_t_t_4<half, half4, half, half4, N_R0_F, N_SIMDWIDTH>) mul_mv_t_t_4;
typedef decltype(kernel_mul_mv_t_t_4<half, half4, half, half4, N_R0_F>) mul_mv_t_t_4;
template [[host_name("kernel_mul_mv_f32_f32_4")]] kernel mul_mv_t_t_4 kernel_mul_mv_t_t_4<float, float4, float, float4, N_R0_F, N_SIMDWIDTH>;
template [[host_name("kernel_mul_mv_f16_f32_4")]] kernel mul_mv_t_t_4 kernel_mul_mv_t_t_4<half, half4, float, float4, N_R0_F, N_SIMDWIDTH>;
template [[host_name("kernel_mul_mv_f16_f16_4")]] kernel mul_mv_t_t_4 kernel_mul_mv_t_t_4<half, half4, half, half4, N_R0_F, N_SIMDWIDTH>;
template [[host_name("kernel_mul_mv_f32_f32_4")]] kernel mul_mv_t_t_4 kernel_mul_mv_t_t_4<float, float4, float, float4, N_R0_F>;
template [[host_name("kernel_mul_mv_f16_f32_4")]] kernel mul_mv_t_t_4 kernel_mul_mv_t_t_4<half, half4, float, float4, N_R0_F>;
template [[host_name("kernel_mul_mv_f16_f16_4")]] kernel mul_mv_t_t_4 kernel_mul_mv_t_t_4<half, half4, half, half4, N_R0_F>;
#if defined(GGML_METAL_HAS_BF16)
template [[host_name("kernel_mul_mv_bf16_f32_4")]] kernel mul_mv_t_t_4 kernel_mul_mv_t_t_4<bfloat, bfloat4, float, float4, N_R0_F, N_SIMDWIDTH>;
template [[host_name("kernel_mul_mv_bf16_bf16_4")]] kernel mul_mv_t_t_4 kernel_mul_mv_t_t_4<bfloat, bfloat4, bfloat, bfloat4, N_R0_F, N_SIMDWIDTH>;
template [[host_name("kernel_mul_mv_bf16_f32_4")]] kernel mul_mv_t_t_4 kernel_mul_mv_t_t_4<bfloat, bfloat4, float, float4, N_R0_F>;
template [[host_name("kernel_mul_mv_bf16_bf16_4")]] kernel mul_mv_t_t_4 kernel_mul_mv_t_t_4<bfloat, bfloat4, bfloat, bfloat4, N_R0_F>;
#endif
#define N_MV_T_T 4
@@ -5966,7 +5969,7 @@ kernel void kernel_concat(
}
}
template<int nr0, int nw, typename args_t>
template<int nr0, typename args_t>
void kernel_mul_mv_q2_K_f32_impl(
args_t args,
device const char * src0,
@@ -6068,10 +6071,10 @@ kernel void kernel_mul_mv_q2_K_f32(
ushort tiisg[[thread_index_in_simdgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
kernel_mul_mv_q2_K_f32_impl<N_R0_Q2_K, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
kernel_mul_mv_q2_K_f32_impl<N_R0_Q2_K, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
}
template<int nr0, int nw, typename args_t>
template<int nr0, typename args_t>
void kernel_mul_mv_q3_K_f32_impl(
args_t args,
device const char * src0,
@@ -6233,10 +6236,10 @@ kernel void kernel_mul_mv_q3_K_f32(
ushort tiisg[[thread_index_in_simdgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
kernel_mul_mv_q3_K_f32_impl<N_R0_Q3_K, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
kernel_mul_mv_q3_K_f32_impl<N_R0_Q3_K, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
}
template<int nr0, int nw, typename args_t>
template<int nr0, typename args_t>
void kernel_mul_mv_q4_K_f32_impl(
args_t args,
device const char * src0,
@@ -6248,9 +6251,9 @@ void kernel_mul_mv_q4_K_f32_impl(
ushort sgitg) {
const short NSG = FC_mul_mv_nsg;
const uint16_t kmask1 = 0x3f3f;
const uint16_t kmask2 = 0x0f0f;
const uint16_t kmask3 = 0xc0c0;
constexpr uint16_t kmask1 = 0x3f3f;
constexpr uint16_t kmask2 = 0x0f0f;
constexpr uint16_t kmask3 = 0xc0c0;
const short ix = tiisg/8; // 0...3
const short it = tiisg%8; // 0...7
@@ -6309,7 +6312,7 @@ void kernel_mul_mv_q4_K_f32_impl(
float4 acc1 = {0.f, 0.f, 0.f, 0.f};
float4 acc2 = {0.f, 0.f, 0.f, 0.f};
for (short i = 0; i < 4; ++i) {
FOR_UNROLL (short i = 0; i < 4; ++i) {
acc1[0] += yl[2*i + 0] * (q1[i] & 0x000F);
acc1[1] += yl[2*i + 1] * (q1[i] & 0x0F00);
acc1[2] += yl[2*i + 8] * (q1[i] & 0x00F0);
@@ -6320,14 +6323,11 @@ void kernel_mul_mv_q4_K_f32_impl(
acc2[3] += yh[2*i + 9] * (q2[i] & 0xF000);
}
float dall = dh[0];
float dmin = dh[1];
sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc1[1]) * sc8[0] +
(acc1[2] + 1.f/256.f * acc1[3]) * sc8[1] * 1.f/16.f +
(acc2[0] + 1.f/256.f * acc2[1]) * sc8[4] +
(acc2[2] + 1.f/256.f * acc2[3]) * sc8[5] * 1.f/16.f) -
dmin * (sumy[0] * sc8[2] + sumy[1] * sc8[3] + sumy[2] * sc8[6] + sumy[3] * sc8[7]);
sumf[row] += dh[0] * ((acc1[0] + 1.f/256.f * acc1[1]) * sc8[0] +
(acc1[2] + 1.f/256.f * acc1[3]) * sc8[1] * 1.f/16.f +
(acc2[0] + 1.f/256.f * acc2[1]) * sc8[4] +
(acc2[2] + 1.f/256.f * acc2[3]) * sc8[5] * 1.f/16.f) -
dh[1] * (sumy[0] * sc8[2] + sumy[1] * sc8[3] + sumy[2] * sc8[6] + sumy[3] * sc8[7]);
q1 += args.nb01/2;
sc += args.nb01/2;
@@ -6357,10 +6357,10 @@ kernel void kernel_mul_mv_q4_K_f32(
ushort tiisg[[thread_index_in_simdgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
kernel_mul_mv_q4_K_f32_impl<N_R0_Q4_K, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
kernel_mul_mv_q4_K_f32_impl<N_R0_Q4_K, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
}
template<int nr0, int nw, typename args_t>
template<int nr0, typename args_t>
void kernel_mul_mv_q5_K_f32_impl(
args_t args,
device const char * src0,
@@ -6393,9 +6393,9 @@ void kernel_mul_mv_q5_K_f32_impl(
float yl[16], yh[16];
const uint16_t kmask1 = 0x3f3f;
const uint16_t kmask2 = 0x0f0f;
const uint16_t kmask3 = 0xc0c0;
constexpr uint16_t kmask1 = 0x3f3f;
constexpr uint16_t kmask2 = 0x0f0f;
constexpr uint16_t kmask3 = 0xc0c0;
const short tid = tiisg/4;
const short ix = tiisg%4;
@@ -6441,7 +6441,7 @@ void kernel_mul_mv_q5_K_f32_impl(
float4 acc1 = {0.f};
float4 acc2 = {0.f};
for (short l = 0; l < 8; ++l) {
FOR_UNROLL (short l = 0; l < 8; ++l) {
uint8_t h = qh[l];
acc1[0] += yl[l+0] * (q1[l] & 0x0F);
acc1[1] += yl[l+8] * (q1[l] & 0xF0);
@@ -6452,13 +6452,12 @@ void kernel_mul_mv_q5_K_f32_impl(
acc2[2] += h & hm3 ? yh[l+0] : 0.f;
acc2[3] += h & hm4 ? yh[l+8] : 0.f;
}
const float dall = dh[0];
const float dmin = dh[1];
sumf[row] += dall * (sc8[0] * (acc1[0] + 16.f*acc2[0]) +
sc8[1] * (acc1[1]/16.f + 16.f*acc2[1]) +
sc8[4] * (acc1[2] + 16.f*acc2[2]) +
sc8[5] * (acc1[3]/16.f + 16.f*acc2[3])) -
dmin * (sumy[0] * sc8[2] + sumy[1] * sc8[3] + sumy[2] * sc8[6] + sumy[3] * sc8[7]);
sumf[row] += dh[0] * (sc8[0] * (acc1[0] + 16.f*acc2[0]) +
sc8[1] * (acc1[1]/16.f + 16.f*acc2[1]) +
sc8[4] * (acc1[2] + 16.f*acc2[2]) +
sc8[5] * (acc1[3]/16.f + 16.f*acc2[3])) -
dh[1] * (sumy[0] * sc8[2] + sumy[1] * sc8[3] + sumy[2] * sc8[6] + sumy[3] * sc8[7]);
q1 += args.nb01;
qh += args.nb01;
@@ -6489,10 +6488,10 @@ kernel void kernel_mul_mv_q5_K_f32(
ushort tiisg[[thread_index_in_simdgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
kernel_mul_mv_q5_K_f32_impl<N_R0_Q5_K, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
kernel_mul_mv_q5_K_f32_impl<N_R0_Q5_K, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
}
template<int nr0, int nw, typename args_t>
template<int nr0, typename args_t>
void kernel_mul_mv_q6_K_f32_impl(
args_t args,
device const char * src0,
@@ -6504,10 +6503,10 @@ void kernel_mul_mv_q6_K_f32_impl(
ushort sgitg) {
const short NSG = FC_mul_mv_nsg;
const uint8_t kmask1 = 0x03;
const uint8_t kmask2 = 0x0C;
const uint8_t kmask3 = 0x30;
const uint8_t kmask4 = 0xC0;
constexpr uint8_t kmask1 = 0x03;
constexpr uint8_t kmask2 = 0x0C;
constexpr uint8_t kmask3 = 0x30;
constexpr uint8_t kmask4 = 0xC0;
const int nb = args.ne00/QK_K;
@@ -6558,18 +6557,16 @@ void kernel_mul_mv_q6_K_f32_impl(
}
for (short row = 0; row < nr0; ++row) {
const float dall = dh[0];
float4 sums = {0.f, 0.f, 0.f, 0.f};
for (short l = 0; l < 4; ++l) {
FOR_UNROLL (short l = 0; l < 4; ++l) {
sums[0] += yl[4*l + 0] * ((int8_t)((q1[l] & 0xF) | ((qh[l] & kmask1) << 4)) - 32);
sums[1] += yl[4*l + 1] * ((int8_t)((q2[l] & 0xF) | ((qh[l] & kmask2) << 2)) - 32);
sums[2] += yl[4*l + 2] * ((int8_t)((q1[l] >> 4) | ((qh[l] & kmask3) << 0)) - 32);
sums[3] += yl[4*l + 3] * ((int8_t)((q2[l] >> 4) | ((qh[l] & kmask4) >> 2)) - 32);
}
sumf[row] += dall * (sums[0] * sc[0] + sums[1] * sc[2] + sums[2] * sc[4] + sums[3] * sc[6]);
sumf[row] += dh[0] * (sums[0] * sc[0] + sums[1] * sc[2] + sums[2] * sc[4] + sums[3] * sc[6]);
q1 += args.nb01;
q2 += args.nb01;
@@ -6599,12 +6596,12 @@ kernel void kernel_mul_mv_q6_K_f32(
ushort tiisg[[thread_index_in_simdgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
kernel_mul_mv_q6_K_f32_impl<N_R0_Q6_K, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
kernel_mul_mv_q6_K_f32_impl<N_R0_Q6_K, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
}
// ======================= "True" 2-bit
template<int nr0, int nw, typename args_t>
template<int nr0, typename args_t>
void kernel_mul_mv_iq2_xxs_f32_impl(
args_t args,
device const char * src0,
@@ -6709,10 +6706,10 @@ kernel void kernel_mul_mv_iq2_xxs_f32(
uint3 tgpig[[threadgroup_position_in_grid]],
ushort tiisg[[thread_index_in_simdgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
kernel_mul_mv_iq2_xxs_f32_impl<N_R0_IQ2_XXS, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
kernel_mul_mv_iq2_xxs_f32_impl<N_R0_IQ2_XXS, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
}
template<int nr0, int nw, typename args_t>
template<int nr0, typename args_t>
void kernel_mul_mv_iq2_xs_f32_impl(
args_t args,
device const char * src0,
@@ -6828,10 +6825,10 @@ kernel void kernel_mul_mv_iq2_xs_f32(
ushort tiisg[[thread_index_in_simdgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
kernel_mul_mv_iq2_xs_f32_impl<N_R0_IQ2_XS, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
kernel_mul_mv_iq2_xs_f32_impl<N_R0_IQ2_XS, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
}
template<int nr0, int nw, typename args_t>
template<int nr0, typename args_t>
void kernel_mul_mv_iq3_xxs_f32_impl(
args_t args,
device const char * src0,
@@ -6940,10 +6937,10 @@ kernel void kernel_mul_mv_iq3_xxs_f32(
ushort tiisg[[thread_index_in_simdgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
kernel_mul_mv_iq3_xxs_f32_impl<N_R0_IQ3_XXS, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
kernel_mul_mv_iq3_xxs_f32_impl<N_R0_IQ3_XXS, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
}
template<int nr0, int nw, typename args_t>
template<int nr0, typename args_t>
void kernel_mul_mv_iq3_s_f32_impl(
args_t args,
device const char * src0,
@@ -7052,10 +7049,10 @@ kernel void kernel_mul_mv_iq3_s_f32(
ushort tiisg[[thread_index_in_simdgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
kernel_mul_mv_iq3_s_f32_impl<N_R0_IQ3_S, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
kernel_mul_mv_iq3_s_f32_impl<N_R0_IQ3_S, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
}
template<int nr0, int nw, typename args_t>
template<int nr0, typename args_t>
void kernel_mul_mv_iq2_s_f32_impl(
args_t args,
device const char * src0,
@@ -7165,10 +7162,10 @@ kernel void kernel_mul_mv_iq2_s_f32(
ushort tiisg[[thread_index_in_simdgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
kernel_mul_mv_iq2_s_f32_impl<N_R0_IQ2_S, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
kernel_mul_mv_iq2_s_f32_impl<N_R0_IQ2_S, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
}
template<int nr0, int nw, typename args_t>
template<int nr0, typename args_t>
void kernel_mul_mv_iq1_s_f32_impl(
args_t args,
device const char * src0,
@@ -7264,10 +7261,10 @@ kernel void kernel_mul_mv_iq1_s_f32(
ushort tiisg[[thread_index_in_simdgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
kernel_mul_mv_iq1_s_f32_impl<N_R0_IQ1_S, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
kernel_mul_mv_iq1_s_f32_impl<N_R0_IQ1_S, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
}
template<int nr0, int nw, typename args_t>
template<int nr0, typename args_t>
void kernel_mul_mv_iq1_m_f32_impl(
args_t args,
device const char * src0,
@@ -7373,10 +7370,10 @@ kernel void kernel_mul_mv_iq1_m_f32(
ushort tiisg[[thread_index_in_simdgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
kernel_mul_mv_iq1_m_f32_impl<N_R0_IQ1_M, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
kernel_mul_mv_iq1_m_f32_impl<N_R0_IQ1_M, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
}
template<int nr0, int nw, typename args_t>
template<int nr0, typename args_t>
void kernel_mul_mv_iq4_nl_f32_impl(
args_t args,
device const char * src0,
@@ -7480,10 +7477,10 @@ kernel void kernel_mul_mv_iq4_nl_f32(
ushort tiisg[[thread_index_in_simdgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
kernel_mul_mv_iq4_nl_f32_impl<N_R0_IQ4_NL, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
kernel_mul_mv_iq4_nl_f32_impl<N_R0_IQ4_NL, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
}
template<int nr0, int nw, typename args_t>
template<int nr0, typename args_t>
void kernel_mul_mv_iq4_xs_f32_impl(
args_t args,
device const char * src0,
@@ -7587,10 +7584,10 @@ kernel void kernel_mul_mv_iq4_xs_f32(
ushort tiisg[[thread_index_in_simdgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
kernel_mul_mv_iq4_xs_f32_impl<N_R0_IQ4_XS, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
kernel_mul_mv_iq4_xs_f32_impl<N_R0_IQ4_XS, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
}
template<int nr0, int nw, typename args_t>
template<int nr0, typename args_t>
void kernel_mul_mv_mxfp4_f32_impl(
args_t args,
device const char * src0,
@@ -7677,7 +7674,7 @@ kernel void kernel_mul_mv_mxfp4_f32(
ushort tiisg[[thread_index_in_simdgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
kernel_mul_mv_mxfp4_f32_impl<N_R0_MXFP4, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
kernel_mul_mv_mxfp4_f32_impl<N_R0_MXFP4, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
}
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread float4x4 &)>
@@ -8353,7 +8350,7 @@ void mmv_fn(
impl_fn(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
}
typedef decltype(mmv_fn<kernel_mul_mv_t_t_impl<half, half, N_R0_F, N_SIMDWIDTH, ggml_metal_kargs_mul_mv>>) mul_mv_impl_fn_t;
typedef decltype(mmv_fn<kernel_mul_mv_t_t_impl<half, half, N_R0_F, ggml_metal_kargs_mul_mv>>) mul_mv_impl_fn_t;
template<mul_mv_impl_fn_t impl_fn>
kernel void kernel_mul_mv_id(
@@ -8418,44 +8415,44 @@ kernel void kernel_mul_mv_id(
sgitg);
}
typedef decltype(kernel_mul_mv_id<mmv_fn<kernel_mul_mv_t_t_impl<float, float, N_R0_F, N_SIMDWIDTH>>>) kernel_mul_mv_id_t;
typedef decltype(kernel_mul_mv_id<mmv_fn<kernel_mul_mv_t_t_impl<float, float, N_R0_F>>>) kernel_mul_mv_id_t;
typedef decltype(kernel_mul_mv_id<mmv_fn<kernel_mul_mv_t_t_4_impl<float, float4, float, float4, N_R0_F, N_SIMDWIDTH>>>) kernel_mul_mv_id_4_t;
typedef decltype(kernel_mul_mv_id<mmv_fn<kernel_mul_mv_t_t_4_impl<float, float4, float, float4, N_R0_F>>>) kernel_mul_mv_id_4_t;
template [[host_name("kernel_mul_mv_id_f32_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_t_t_impl<float, float, N_R0_F, N_SIMDWIDTH>>>;
template [[host_name("kernel_mul_mv_id_f16_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_t_t_impl<half, float, N_R0_F, N_SIMDWIDTH>>>;
template [[host_name("kernel_mul_mv_id_f32_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_t_t_impl<float, float, N_R0_F>>>;
template [[host_name("kernel_mul_mv_id_f16_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_t_t_impl<half, float, N_R0_F>>>;
#if defined(GGML_METAL_HAS_BF16)
template [[host_name("kernel_mul_mv_id_bf16_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_t_t_impl<bfloat, float, N_R0_F, N_SIMDWIDTH>>>;
template [[host_name("kernel_mul_mv_id_bf16_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_t_t_impl<bfloat, float, N_R0_F>>>;
#endif
template [[host_name("kernel_mul_mv_id_f32_f32_4")]] kernel kernel_mul_mv_id_4_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_t_t_4_impl<float, float4, float, float4, N_R0_F, N_SIMDWIDTH>>>;
template [[host_name("kernel_mul_mv_id_f16_f32_4")]] kernel kernel_mul_mv_id_4_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_t_t_4_impl<half, half4, float, float4, N_R0_F, N_SIMDWIDTH>>>;
template [[host_name("kernel_mul_mv_id_f32_f32_4")]] kernel kernel_mul_mv_id_4_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_t_t_4_impl<float, float4, float, float4, N_R0_F>>>;
template [[host_name("kernel_mul_mv_id_f16_f32_4")]] kernel kernel_mul_mv_id_4_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_t_t_4_impl<half, half4, float, float4, N_R0_F>>>;
#if defined(GGML_METAL_HAS_BF16)
template [[host_name("kernel_mul_mv_id_bf16_f32_4")]] kernel kernel_mul_mv_id_4_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_t_t_4_impl<bfloat, bfloat4, float, float4, N_R0_F, N_SIMDWIDTH>>>;
template [[host_name("kernel_mul_mv_id_bf16_f32_4")]] kernel kernel_mul_mv_id_4_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_t_t_4_impl<bfloat, bfloat4, float, float4, N_R0_F>>>;
#endif
template [[host_name("kernel_mul_mv_id_q8_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q8_0_f32_impl<N_R0_Q8_0, N_SIMDWIDTH>>>;
template [[host_name("kernel_mul_mv_id_q8_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q8_0_f32_impl<N_R0_Q8_0>>>;
template [[host_name("kernel_mul_mv_id_q4_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q4_0, N_R0_Q4_0, N_SIMDWIDTH>>>;
template [[host_name("kernel_mul_mv_id_q4_1_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q4_1, N_R0_Q4_1, N_SIMDWIDTH>>>;
template [[host_name("kernel_mul_mv_id_q5_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q5_0, N_R0_Q5_0, N_SIMDWIDTH>>>;
template [[host_name("kernel_mul_mv_id_q5_1_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q5_1, N_R0_Q5_1, N_SIMDWIDTH>>>;
template [[host_name("kernel_mul_mv_id_q4_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q4_0, N_R0_Q4_0>>>;
template [[host_name("kernel_mul_mv_id_q4_1_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q4_1, N_R0_Q4_1>>>;
template [[host_name("kernel_mul_mv_id_q5_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q5_0, N_R0_Q5_0>>>;
template [[host_name("kernel_mul_mv_id_q5_1_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q5_1, N_R0_Q5_1>>>;
template [[host_name("kernel_mul_mv_id_mxfp4_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_mxfp4_f32_impl<N_R0_MXFP4, N_SIMDWIDTH>>>;
template [[host_name("kernel_mul_mv_id_mxfp4_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_mxfp4_f32_impl<N_R0_MXFP4>>>;
template [[host_name("kernel_mul_mv_id_q2_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q2_K_f32_impl <N_R0_Q2_K, N_SIMDWIDTH>>>;
template [[host_name("kernel_mul_mv_id_q3_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q3_K_f32_impl <N_R0_Q3_K, N_SIMDWIDTH>>>;
template [[host_name("kernel_mul_mv_id_q4_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q4_K_f32_impl <N_R0_Q4_K, N_SIMDWIDTH>>>;
template [[host_name("kernel_mul_mv_id_q5_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q5_K_f32_impl <N_R0_Q5_K, N_SIMDWIDTH>>>;
template [[host_name("kernel_mul_mv_id_q6_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q6_K_f32_impl <N_R0_Q6_K, N_SIMDWIDTH>>>;
template [[host_name("kernel_mul_mv_id_iq1_s_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq1_s_f32_impl <N_R0_IQ1_S, N_SIMDWIDTH>>>;
template [[host_name("kernel_mul_mv_id_iq1_m_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq1_m_f32_impl <N_R0_IQ1_M, N_SIMDWIDTH>>>;
template [[host_name("kernel_mul_mv_id_iq2_xxs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq2_xxs_f32_impl<N_R0_IQ2_XXS, N_SIMDWIDTH>>>;
template [[host_name("kernel_mul_mv_id_iq2_xs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq2_xs_f32_impl <N_R0_IQ2_XS, N_SIMDWIDTH>>>;
template [[host_name("kernel_mul_mv_id_iq3_xxs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq3_xxs_f32_impl<N_R0_IQ3_XXS, N_SIMDWIDTH>>>;
template [[host_name("kernel_mul_mv_id_iq3_s_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq3_s_f32_impl <N_R0_IQ3_S, N_SIMDWIDTH>>>;
template [[host_name("kernel_mul_mv_id_iq2_s_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq2_s_f32_impl <N_R0_IQ2_S, N_SIMDWIDTH>>>;
template [[host_name("kernel_mul_mv_id_iq4_nl_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq4_nl_f32_impl <N_R0_IQ4_NL, N_SIMDWIDTH>>>;
template [[host_name("kernel_mul_mv_id_iq4_xs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq4_xs_f32_impl <N_R0_IQ4_XS, N_SIMDWIDTH>>>;
template [[host_name("kernel_mul_mv_id_q2_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q2_K_f32_impl <N_R0_Q2_K>>>;
template [[host_name("kernel_mul_mv_id_q3_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q3_K_f32_impl <N_R0_Q3_K>>>;
template [[host_name("kernel_mul_mv_id_q4_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q4_K_f32_impl <N_R0_Q4_K>>>;
template [[host_name("kernel_mul_mv_id_q5_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q5_K_f32_impl <N_R0_Q5_K>>>;
template [[host_name("kernel_mul_mv_id_q6_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q6_K_f32_impl <N_R0_Q6_K>>>;
template [[host_name("kernel_mul_mv_id_iq1_s_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq1_s_f32_impl <N_R0_IQ1_S>>>;
template [[host_name("kernel_mul_mv_id_iq1_m_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq1_m_f32_impl <N_R0_IQ1_M>>>;
template [[host_name("kernel_mul_mv_id_iq2_xxs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq2_xxs_f32_impl<N_R0_IQ2_XXS>>>;
template [[host_name("kernel_mul_mv_id_iq2_xs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq2_xs_f32_impl <N_R0_IQ2_XS>>>;
template [[host_name("kernel_mul_mv_id_iq3_xxs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq3_xxs_f32_impl<N_R0_IQ3_XXS>>>;
template [[host_name("kernel_mul_mv_id_iq3_s_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq3_s_f32_impl <N_R0_IQ3_S>>>;
template [[host_name("kernel_mul_mv_id_iq2_s_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq2_s_f32_impl <N_R0_IQ2_S>>>;
template [[host_name("kernel_mul_mv_id_iq4_nl_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq4_nl_f32_impl <N_R0_IQ4_NL>>>;
template [[host_name("kernel_mul_mv_id_iq4_xs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq4_xs_f32_impl <N_R0_IQ4_XS>>>;
kernel void kernel_pool_2d_max_f32(
constant ggml_metal_kargs_pool_2d & args,

View File

@@ -83,8 +83,10 @@ set(GGML_OPENCL_KERNELS
mul_mv_q4_0_f32_1d_16x_flat
mul_mv_q6_k
mul_mv_mxfp4_f32
mul_mv_mxfp4_f32_flat
mul_mv_id_q4_0_f32_8x_flat
mul_mv_id_mxfp4_f32
mul_mv_id_mxfp4_f32_flat
mul_mm_f32_f32_l4_lm
mul_mm_f16_f32_l4_lm
mul

View File

@@ -368,6 +368,7 @@ struct ggml_backend_opencl_context {
cl_program program_mul_mv_q4_0_f32_1d_16x_flat;
cl_program program_mul_mv_q6_K;
cl_program program_mul_mv_mxfp4_f32;
cl_program program_mul_mv_mxfp4_f32_flat;
cl_program program_mul_mv_f16_f16;
cl_program program_mul_mv_f16_f32_1row;
cl_program program_mul_mv_f16_f32_l4;
@@ -402,6 +403,7 @@ struct ggml_backend_opencl_context {
cl_program program_tsembd;
cl_program program_mul_mv_id_q4_0_f32_8x_flat;
cl_program program_mul_mv_id_mxfp4_f32;
cl_program program_mul_mv_id_mxfp4_f32_flat;
cl_program program_mul_mm_f32_f32_l4_lm;
cl_program program_mul_mm_f16_f32_l4_lm;
@@ -447,11 +449,12 @@ struct ggml_backend_opencl_context {
cl_kernel kernel_mul_mat_f16_f32_tiled;
cl_kernel kernel_mul_mat_q4_0_f32, kernel_mul_mat_q4_0_f32_v;
cl_kernel kernel_convert_block_q4_0, kernel_restore_block_q4_0;
cl_kernel kernel_convert_block_mxfp4, kernel_restore_block_mxfp4;
cl_kernel kernel_mul_mat_q4_0_f32_8x_flat;
cl_kernel kernel_convert_block_q4_0_noshuffle;
cl_kernel kernel_mul_mat_q4_0_f32_1d_8x_flat, kernel_mul_mat_q4_0_f32_1d_16x_flat;
cl_kernel kernel_mul_mv_q6_K_f32;
cl_kernel kernel_mul_mv_mxfp4_f32;
cl_kernel kernel_mul_mv_mxfp4_f32, kernel_mul_mv_mxfp4_f32_flat;
cl_kernel kernel_im2col_f32, kernel_im2col_f16;
cl_kernel kernel_argsort_f32_i32;
cl_kernel kernel_sum_rows_f32;
@@ -469,6 +472,7 @@ struct ggml_backend_opencl_context {
cl_kernel kernel_timestep_embedding;
cl_kernel kernel_mul_mv_id_q4_0_f32_8x_flat;
cl_kernel kernel_mul_mv_id_mxfp4_f32;
cl_kernel kernel_mul_mv_id_mxfp4_f32_flat;
cl_kernel kernel_mul_mm_f32_f32_l4_lm;
cl_kernel kernel_mul_mm_f16_f32_l4_lm;
@@ -765,6 +769,8 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
CL_CHECK((backend_ctx->kernel_convert_block_q4_0_noshuffle = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q4_0_noshuffle", &err), err));
CL_CHECK((backend_ctx->kernel_convert_block_q4_0 = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q4_0", &err), err));
CL_CHECK((backend_ctx->kernel_restore_block_q4_0 = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q4_0", &err), err));
CL_CHECK((backend_ctx->kernel_convert_block_mxfp4 = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_mxfp4", &err), err));
CL_CHECK((backend_ctx->kernel_restore_block_mxfp4 = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_mxfp4", &err), err));
GGML_LOG_CONT(".");
}
@@ -1002,6 +1008,22 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
GGML_LOG_CONT(".");
}
// mul_mv_mxfp4_f32_flat
{
#ifdef GGML_OPENCL_EMBED_KERNELS
const std::string kernel_src {
#include "mul_mv_mxfp4_f32_flat.cl.h"
};
#else
const std::string kernel_src = read_file("mul_mv_mxfp4_f32_flat.cl");
#endif
backend_ctx->program_mul_mv_mxfp4_f32_flat =
build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
CL_CHECK((backend_ctx->kernel_mul_mv_mxfp4_f32_flat = clCreateKernel(backend_ctx->program_mul_mv_mxfp4_f32_flat, "kernel_mul_mv_mxfp4_f32_flat", &err), err));
GGML_LOG_CONT(".");
}
// mul_mv_f16_f16
{
#ifdef GGML_OPENCL_EMBED_KERNELS
@@ -1727,6 +1749,22 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
GGML_LOG_CONT(".");
}
// mul_mv_id_mxfp4_f32_flat
{
#ifdef GGML_OPENCL_EMBED_KERNELS
const std::string kernel_src {
#include "mul_mv_id_mxfp4_f32_flat.cl.h"
};
#else
const std::string kernel_src = read_file("mul_mv_id_mxfp4_f32_flat.cl");
#endif
backend_ctx->program_mul_mv_id_mxfp4_f32_flat =
build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
CL_CHECK((backend_ctx->kernel_mul_mv_id_mxfp4_f32_flat = clCreateKernel(backend_ctx->program_mul_mv_id_mxfp4_f32_flat, "kernel_mul_mv_id_mxfp4_f32_flat", &err), err));
GGML_LOG_CONT(".");
}
// Adreno kernels
#ifdef GGML_OPENCL_USE_ADRENO_KERNELS
// transpose
@@ -2391,6 +2429,51 @@ struct ggml_tensor_extra_cl_q4_0 {
}
};
struct ggml_tensor_extra_cl_mxfp4 {
// Quantized values.
cl_mem q = nullptr;
// Quantized values in image1d_buffer_t.
cl_mem q_img = nullptr;
// Scales in E8M0.
cl_mem e = nullptr;
// Scales in image1d_buffer_t.
cl_mem e_img = nullptr;
// Size of quantized values.
size_t size_q = 0;
// Size of scales.
size_t size_e = 0;
~ggml_tensor_extra_cl_mxfp4() {
reset();
}
void reset() {
// q and d are subbuffers into the bigger buffer allocated in ggml_backend_buffer.
// They must be properly released so that the original buffer can be
// properly released to avoid memory leak.
if (q != nullptr) {
CL_CHECK(clReleaseMemObject(q));
q = nullptr;
}
if (e != nullptr) {
CL_CHECK(clReleaseMemObject(e));
e = nullptr;
}
if (q != nullptr) {
CL_CHECK(clReleaseMemObject(q_img));
q = nullptr;
}
// Currently, q_img and d_img are only initialized when SMALL_ALLOC is
// enabled. They point to the images in ggml_backend_opencl_buffer_context.
// So, there is no need to release them here.
// TODO: initialize them for non SMALL_PATH path, or remove them.
q_img = nullptr;
e_img = nullptr;
size_q = 0;
size_e = 0;
}
};
//------------------------------------------------------------------------------
// Backend API
//------------------------------------------------------------------------------
@@ -2838,7 +2921,7 @@ static ggml_backend_i ggml_backend_opencl_i = {
/* .graph_compute = */ ggml_backend_opencl_graph_compute,
/* .event_record = */ NULL,
/* .event_wait = */ NULL,
/* .optimize_graph = */ NULL,
/* .graph_optimize = */ NULL,
};
ggml_backend_t ggml_backend_opencl_init(void) {
@@ -2894,6 +2977,12 @@ struct ggml_backend_opencl_buffer_context {
for (ggml_tensor_extra_cl_q4_0 * e : temp_tensor_extras_q4_0_in_use) {
delete e;
}
for (ggml_tensor_extra_cl_mxfp4 * e : temp_tensor_extras_mxfp4) {
delete e;
}
for (ggml_tensor_extra_cl_mxfp4 * e : temp_tensor_extras_mxfp4_in_use) {
delete e;
}
}
ggml_tensor_extra_cl * ggml_opencl_alloc_temp_tensor_extra() {
@@ -2926,6 +3015,21 @@ struct ggml_backend_opencl_buffer_context {
return extra;
}
ggml_tensor_extra_cl_mxfp4 * ggml_opencl_alloc_temp_tensor_extra_mxfp4() {
ggml_tensor_extra_cl_mxfp4 * extra;
if (temp_tensor_extras_mxfp4.empty()) {
extra = new ggml_tensor_extra_cl_mxfp4();
} else {
extra = temp_tensor_extras_mxfp4.back();
temp_tensor_extras_mxfp4.pop_back();
}
temp_tensor_extras_mxfp4_in_use.push_back(extra);
extra->reset();
return extra;
}
void reset() {
for (ggml_tensor_extra_cl * e : temp_tensor_extras_in_use) {
temp_tensor_extras.push_back(e);
@@ -2936,6 +3040,11 @@ struct ggml_backend_opencl_buffer_context {
temp_tensor_extras_q4_0.push_back(e);
}
temp_tensor_extras_q4_0_in_use.clear();
for (ggml_tensor_extra_cl_mxfp4 * e : temp_tensor_extras_mxfp4_in_use) {
temp_tensor_extras_mxfp4.push_back(e);
}
temp_tensor_extras_mxfp4_in_use.clear();
}
// Pools for extras. Available extras are in `temp_tensor_extras`. Extras
@@ -2947,6 +3056,8 @@ struct ggml_backend_opencl_buffer_context {
std::vector<ggml_tensor_extra_cl *> temp_tensor_extras_in_use;
std::vector<ggml_tensor_extra_cl_q4_0 *> temp_tensor_extras_q4_0;
std::vector<ggml_tensor_extra_cl_q4_0 *> temp_tensor_extras_q4_0_in_use;
std::vector<ggml_tensor_extra_cl_mxfp4 *> temp_tensor_extras_mxfp4;
std::vector<ggml_tensor_extra_cl_mxfp4 *> temp_tensor_extras_mxfp4_in_use;
// The buffer_context is initially created by ggml_backend_buft_alloc_buffer
// before any tensor is initialized (at the beginning of alloc_tensor_range).
@@ -3289,6 +3400,76 @@ static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer,
}
#endif // GGML_OPENCL_USE_ADRENO_KERNELS
return;
}
if (tensor->type == GGML_TYPE_MXFP4) {
ggml_tensor_extra_cl * extra_orig = (ggml_tensor_extra_cl *)tensor->extra;
GGML_ASSERT(extra_orig && "Tesnors in OpenCL backend should have been allocated and initialized");
// Allocate the new extra and create aliases from the original.
ggml_backend_opencl_buffer_context * ctx = (ggml_backend_opencl_buffer_context *) buffer->context;
ggml_tensor_extra_cl_mxfp4 * extra = ctx->ggml_opencl_alloc_temp_tensor_extra_mxfp4();
size_t size_e = ggml_nelements(tensor)/ggml_blck_size(tensor->type)*sizeof(char);
size_t size_q = ggml_nelements(tensor)/ggml_blck_size(tensor->type)*ggml_blck_size(tensor->type)/2;
GGML_ASSERT(size_e + size_q == ggml_nbytes(tensor) && "Incorrect tensor size");
cl_int err;
cl_mem data_device = clCreateBuffer(context, CL_MEM_READ_WRITE,
ggml_nbytes(tensor), NULL, &err);
CL_CHECK(err);
CL_CHECK(clEnqueueWriteBuffer(
queue, data_device, CL_TRUE, 0,
ggml_nbytes(tensor), data, 0, NULL, NULL));
// The original tensor memory is divided into scales and quants, i.e.,
// we first store scales, then quants.
cl_buffer_region region;
// Create subbuffer for scales.
region.origin = align_to(extra_orig->offset + tensor->view_offs + offset, backend_ctx->alignment);
region.size = size_e;
extra->e = clCreateSubBuffer(
extra_orig->data_device, CL_MEM_READ_WRITE,
CL_BUFFER_CREATE_TYPE_REGION, &region, &err);
CL_CHECK(err);
auto previous_origin = region.origin;
// Create subbuffer for quants.
region.origin = align_to(previous_origin + size_e, backend_ctx->alignment);
region.size = size_q;
extra->q = clCreateSubBuffer(
extra_orig->data_device, CL_MEM_READ_WRITE,
CL_BUFFER_CREATE_TYPE_REGION, &region, &err);
CL_CHECK(err);
cl_kernel kernel = backend_ctx->kernel_convert_block_mxfp4;
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &data_device));
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->q));
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra->e));
size_t global_work_size[] = {(size_t)ggml_nelements(tensor)/ggml_blck_size(tensor->type), 1, 1};
size_t local_work_size[] = {64, 1, 1};
cl_event evt;
CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt));
CL_CHECK(clWaitForEvents(1, &evt));
CL_CHECK(clReleaseMemObject(data_device));
// Create image for Q
cl_image_format img_format_q = {CL_RG, CL_UNSIGNED_INT32};
cl_image_desc img_desc_q = {
CL_MEM_OBJECT_IMAGE1D_BUFFER,
static_cast<size_t>(ggml_nelements(tensor)/32*2),
0, 0, 0, 0, 0, 0, 0,
{ extra->q }
};
extra->q_img = clCreateImage(context, CL_MEM_READ_ONLY, &img_format_q, &img_desc_q, NULL, &err);
tensor->extra = extra;
return;
}
#endif // GGML_OPENCL_SOA_Q
@@ -3337,6 +3518,31 @@ static void ggml_backend_opencl_buffer_get_tensor(ggml_backend_buffer_t buffer,
size_t global_work_size[] = {(size_t)ggml_nelements(tensor)/ggml_blck_size(tensor->type), 1, 1};
size_t local_work_size[] = {1, 1, 1};
cl_event evt;
CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL,
global_work_size, local_work_size, 0, NULL, &evt));
CL_CHECK(clWaitForEvents(1, &evt));
CL_CHECK(clEnqueueReadBuffer(
queue, data_device, CL_TRUE, offset,
size, data, 0, NULL, NULL));
CL_CHECK(clReleaseMemObject(data_device));
return;
} else if (tensor->type == GGML_TYPE_MXFP4) {
ggml_tensor_extra_cl_mxfp4 * extra = (ggml_tensor_extra_cl_mxfp4 *)tensor->extra;
cl_int err;
cl_mem data_device = clCreateBuffer(context, CL_MEM_READ_WRITE,
ggml_nbytes(tensor), NULL, &err);
CL_CHECK(err);
cl_kernel kernel = backend_ctx->kernel_restore_block_mxfp4;
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra->q));
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->e));
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &data_device));
size_t global_work_size[] = {(size_t)ggml_nelements(tensor)/ggml_blck_size(tensor->type), 1, 1};
size_t local_work_size[] = {1, 1, 1};
cl_event evt;
CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL,
global_work_size, local_work_size, 0, NULL, &evt));
@@ -3658,6 +3864,19 @@ static void dump_tensor(ggml_backend_t backend, const struct ggml_tensor * tenso
CL_CHECK(clEnqueueReadBuffer(queue, extra->q, CL_TRUE, 0, size_q, buf_q, 0, NULL, NULL));
CL_CHECK(clEnqueueReadBuffer(queue, extra->d, CL_TRUE, 0, size_d, buf_d, 0, NULL, NULL));
CL_CHECK(clFinish(queue));
} else if (tensor->type == GGML_TYPE_MXFP4) {
ggml_tensor_extra_cl_mxfp4 * extra = (ggml_tensor_extra_cl_mxfp4 *) tensor->extra;
GGML_ASSERT(extra);
size_t size_q = ggml_nelements(tensor)/QK_MXFP4 * QK_MXFP4/2;
size_t size_e = ggml_nelements(tensor)/QK_MXFP4 * sizeof(char);
GGML_ASSERT(size_q + size_e == ggml_nbytes(tensor));
buf_q = malloc(size_q);
buf_d = malloc(size_e);
CL_CHECK(clEnqueueReadBuffer(queue, extra->q, CL_TRUE, 0, size_q, buf_q, 0, NULL, NULL));
CL_CHECK(clEnqueueReadBuffer(queue, extra->d, CL_TRUE, 0, size_e, buf_d, 0, NULL, NULL));
CL_CHECK(clFinish(queue));
} else {
// Read out the tensor from GPU memory.
ggml_tensor_extra_cl * extra = (ggml_tensor_extra_cl *) tensor->extra;
@@ -6048,6 +6267,7 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co
#ifdef GGML_OPENCL_SOA_Q
ggml_tensor_extra_cl_q4_0 * extra0_q4_0 = (ggml_tensor_extra_cl_q4_0 *)src0->extra;
ggml_tensor_extra_cl_mxfp4 * extra0_mxfp4 = (ggml_tensor_extra_cl_mxfp4 *)src0->extra;
#endif
const int ne00 = src0 ? src0->ne[0] : 0;
@@ -6752,6 +6972,45 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co
CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &r3));
break;
case GGML_TYPE_MXFP4: {
#ifdef GGML_OPENCL_SOA_Q
kernel = backend_ctx->kernel_mul_mv_mxfp4_f32_flat;
cl_mem q;
if (backend_ctx->gpu_family == INTEL) {
nth0 = 16;
nth1 = 2;
ndst = nth1*2;
q = extra0_mxfp4->q;
} else if (backend_ctx->gpu_family == ADRENO) {
nth0 = 64;
nth1 = 2;
ndst = nth1*2;
q = extra0_mxfp4->q_img;
} else {
GGML_ASSERT(false && "TODO: Unknown GPU");
}
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &q));
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra0_mxfp4->e));
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device));
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1));
CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device));
CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd));
CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne00));
CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_ulong), &nb01));
CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &nb02));
CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nb03));
CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne12));
CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb11));
CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb12));
CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &nb13));
CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &ne0));
CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &ne1));
CL_CHECK(clSetKernelArg(kernel, 16, sizeof(int), &r2));
CL_CHECK(clSetKernelArg(kernel, 17, sizeof(int), &r3));
#else
kernel = backend_ctx->kernel_mul_mv_mxfp4_f32;
if (backend_ctx->gpu_family == INTEL) {
@@ -6785,6 +7044,7 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co
CL_CHECK(clSetKernelArg(kernel, 16, sizeof(int), &r2));
CL_CHECK(clSetKernelArg(kernel, 17, sizeof(int), &r3));
CL_CHECK(clSetKernelArg(kernel, 18, sizeof(float)*nth0,nullptr));
#endif
break;
}
default:
@@ -6850,8 +7110,11 @@ static void ggml_cl_mul_mat_id(ggml_backend_t backend, const ggml_tensor * src0,
cl_ulong offset2 = extra2->offset + src2->view_offs;
cl_ulong offsetd = extrad->offset + dst->view_offs;
GGML_UNUSED(offset0);
#ifdef GGML_OPENCL_SOA_Q
ggml_tensor_extra_cl_q4_0 * extra0_q4_0 = (ggml_tensor_extra_cl_q4_0 *)src0->extra;
ggml_tensor_extra_cl_mxfp4 * extra0_mxfp4 = (ggml_tensor_extra_cl_mxfp4 *)src0->extra;
#endif
const int ne00 = src0->ne[0];
@@ -6940,6 +7203,51 @@ static void ggml_cl_mul_mat_id(ggml_backend_t backend, const ggml_tensor * src0,
break;
}
case GGML_TYPE_MXFP4: {
#ifdef GGML_OPENCL_SOA_Q
kernel = backend_ctx->kernel_mul_mv_id_mxfp4_f32_flat;
cl_mem q;
if (backend_ctx->gpu_family == INTEL) {
sgs = 16;
nsg = 2;
ndst = 2;
q = extra0_mxfp4->q;
} else if (backend_ctx->gpu_family == ADRENO) {
sgs = 64;
nsg = 1;
ndst = 4;
q = extra0_mxfp4->q_img;
} else {
GGML_ASSERT(false && "TODO: Unknown GPU");
}
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &q));
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra0_mxfp4->e));
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device));
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1));
CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extra2->data_device));
CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offset2));
CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_mem), &extrad->data_device));
CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_ulong), &offsetd));
CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne00));
CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nb01));
CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb02));
CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb03));
CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &ne11));
CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &ne12));
CL_CHECK(clSetKernelArg(kernel, 14, sizeof(cl_ulong), &nb11));
CL_CHECK(clSetKernelArg(kernel, 15, sizeof(cl_ulong), &nb12));
CL_CHECK(clSetKernelArg(kernel, 16, sizeof(cl_ulong), &nb13));
CL_CHECK(clSetKernelArg(kernel, 17, sizeof(int), &ne20));
CL_CHECK(clSetKernelArg(kernel, 18, sizeof(int), &ne21));
CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_ulong), &nb21));
CL_CHECK(clSetKernelArg(kernel, 20, sizeof(int), &ne0));
CL_CHECK(clSetKernelArg(kernel, 21, sizeof(int), &ne1));
CL_CHECK(clSetKernelArg(kernel, 22, sizeof(int), &r2));
CL_CHECK(clSetKernelArg(kernel, 23, sizeof(int), &r3));
#else // GGML_OPENCL_SOA_Q
kernel = backend_ctx->kernel_mul_mv_id_mxfp4_f32;
if (backend_ctx->gpu_family == INTEL) {
@@ -6979,7 +7287,7 @@ static void ggml_cl_mul_mat_id(ggml_backend_t backend, const ggml_tensor * src0,
CL_CHECK(clSetKernelArg(kernel, 22, sizeof(int), &r2));
CL_CHECK(clSetKernelArg(kernel, 23, sizeof(int), &r3));
CL_CHECK(clSetKernelArg(kernel, 24, sizeof(float)*sgs,nullptr));
#endif // GGML_OPENCL_SOA_Q
break;
}
default:

View File

@@ -116,3 +116,49 @@ kernel void kernel_convert_block_q4_0_noshuffle(
#endif
}
}
//------------------------------------------------------------------------------
// block_q4_0
//------------------------------------------------------------------------------
#define QK_MXFP4 32
struct block_mxfp4 {
uchar e; // E8M0
uchar qs[QK_MXFP4 / 2];
};
//------------------------------------------------------------------------------
// kernel_convert_block_mxfp4
// Convert the block_mxfp4 format to 2 separate arrays (AOS -> SOA).
// This kernel does not deshuffle the bits.
//------------------------------------------------------------------------------
kernel void kernel_convert_block_mxfp4(
global struct block_mxfp4 * src0,
global uchar * dst_q,
global uchar * dst_e
) {
global struct block_mxfp4 * b = (global struct block_mxfp4 *) src0 + get_global_id(0);
global uchar * q = (global uchar *) dst_q + QK_MXFP4 / 2 * get_global_id(0);
global uchar * e = (global uchar *) dst_e + get_global_id(0);
*e = b->e;
for (int i = 0; i < QK_MXFP4 / 2; ++i) {
q[i] = b->qs[i];
}
}
kernel void kernel_restore_block_mxfp4(
global uchar * src_q,
global half * src_e,
global struct block_mxfp4 * dst
) {
global struct block_mxfp4 * b = (global struct block_mxfp4 *) dst + get_global_id(0);
global uchar * q = (global uchar *) src_q + QK_MXFP4 / 2 * get_global_id(0);
global uchar * e = (global uchar *) src_e + get_global_id(0);
b->e = *e;
for (int i = 0; i < QK_MXFP4 / 2; ++i) {
b->qs[i] = q[i];
}
}

View File

@@ -0,0 +1,176 @@
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
#ifdef cl_intel_subgroups
#pragma OPENCL EXTENSION cl_intel_subgroups : enable
#else
#pragma OPENCL EXTENSION cl_khr_subgroups : enable
#endif
#ifdef cl_intel_required_subgroup_size
#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable
#define INTEL_GPU 1
#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16)))
#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32)))
#elif defined(cl_qcom_reqd_sub_group_size)
#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
#define ADRENO_GPU 1
#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half")))
#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full")))
#endif
#define QK_MXFP4 32
static inline half4 mxfp4_to_fp16_packed(ushort fp4x4) {
ushort2 fp16_packed_a, fp16_packed_b, bias_a, bias_b, sign_a, sign_b;
fp16_packed_a.lo = (fp4x4 << 9) & 0x0E00;
fp16_packed_a.hi = (fp4x4 << 5) & 0x0E00;
fp16_packed_b.lo = (fp4x4 << 1) & 0x0E00;
fp16_packed_b.hi = (fp4x4 >> 3) & 0x0E00;
bias_a.lo = (fp16_packed_a.lo == 0) ? 0x0 : 0x3800;
bias_a.hi = (fp16_packed_a.hi == 0) ? 0x0 : 0x3800;
bias_b.lo = (fp16_packed_b.lo == 0) ? 0x0 : 0x3800;
bias_b.hi = (fp16_packed_b.hi == 0) ? 0x0 : 0x3800;
fp16_packed_a.lo = (fp16_packed_a.lo == 0x0200) ? 0x0 : fp16_packed_a.lo;
fp16_packed_a.hi = (fp16_packed_a.hi == 0x0200) ? 0x0 : fp16_packed_a.hi;
fp16_packed_b.lo = (fp16_packed_b.lo == 0x0200) ? 0x0 : fp16_packed_b.lo;
fp16_packed_b.hi = (fp16_packed_b.hi == 0x0200) ? 0x0 : fp16_packed_b.hi;
sign_a.lo = (fp4x4 << 12) & 0x8000;
sign_a.hi = (fp4x4 << 8) & 0x8000;
sign_b.lo = (fp4x4 << 4) & 0x8000;
sign_b.hi = fp4x4 & 0x8000;
fp16_packed_a = sign_a + bias_a + fp16_packed_a;
fp16_packed_b = sign_b + bias_b + fp16_packed_b;
return as_half4((ushort4)(fp16_packed_a, fp16_packed_b));
}
static inline float e8m0_to_fp32(uchar x) {
int bits;
bits = (x == 0) ? 0x00400000 : ((uint) x << 23);
return as_float(bits);
}
#ifdef INTEL_GPU
#define N_R0_MXFP4 2 // number of rows each subgroup works on
#define N_SG_MXFP4 2 // number of subgroups in a work group
#define N_SIMDWIDTH 16 // subgroup size
#elif defined (ADRENO_GPU)
#define N_R0_MXFP4 4
#define N_SG_MXFP4 1
#define N_SIMDWIDTH 64
#define SRC0Q_IMG
#endif
kernel void kernel_mul_mv_id_mxfp4_f32_flat(
#ifdef SRC0Q_IMG
__read_only image1d_buffer_t src0_q,
#else
global uchar * src0_q,
#endif
global uchar * src0_e,
global uchar * src1,
ulong offset1,
global uchar * src2,
ulong offset2,
global uchar * dst,
ulong offsetd,
int ne00,
ulong nb01,
ulong nb02,
ulong nb03,
int ne11,
int ne12,
ulong nb11,
ulong nb12,
ulong nb13,
int ne20,
int ne21,
ulong nb21,
int ne0,
int ne1,
int r2,
int r3
) {
dst = dst + offsetd;
const int iid1 = get_group_id(2) / ne20;
const int idx = get_group_id(2) % ne20;
uint i02 = ((global uint *) (src2 + offset2 + iid1 * nb21))[idx];
int i11 = idx % ne11;
int nb = ne00 / QK_MXFP4;
uint src0_off = i02*nb02;
src0_off /= 17; // 17 = sizeof(block_mxfp4)
src0_e = src0_e + src0_off;
dst = dst + (idx * ne0 + iid1 * ne1 * ne0) * sizeof(float);
int r0 = get_group_id(0);
int r1 = get_group_id(1);
int first_row = (r0 * N_SG_MXFP4 + get_sub_group_id()) * N_R0_MXFP4;
uint offset_src0 = first_row*nb01;
offset_src0 /= 17; // 17 = sizeof(block_mxfp4)
#ifdef SRC0Q_IMG
ulong offset_q = src0_off + offset_src0;
#else
src0_q = src0_q + src0_off*16;
global uchar16 * x_q = (global uchar16 *)(src0_q) + offset_src0;
#endif
global uchar * x_e = src0_e + offset_src0;
const short ix = get_sub_group_local_id() >> 1;
const short it = get_sub_group_local_id() & 1;
float sumf[N_R0_MXFP4] = {0.f};
src1 = src1 + offset1 + i11 * nb11 + iid1 * nb12;
global float * y = (global float *) (src1 + r1 * nb11);
global float * yb = y + ix * QK_MXFP4 + it * 8;
for (int ib = ix; ib < nb; ib += N_SIMDWIDTH / 2) {
global float4 * y4 = (global float4 *)yb;
#pragma unroll
for (short row = 0; row < N_R0_MXFP4; row++) {
uchar xb_e = x_e[row * nb + ib];
#ifdef SRC0Q_IMG
ushort4 xb_q = as_ushort4(read_imageui(src0_q, (offset_q + row * nb + ib) * 2 + it).xy);
#else
ushort4 xb_q = vload4(0, (global ushort *)((global uchar *)(x_q + row * nb + ib) + 8 * it));
#endif
half4 fp16x4_0 = mxfp4_to_fp16_packed(xb_q.s0);
half4 fp16x4_1 = mxfp4_to_fp16_packed(xb_q.s1);
float4 acc1 = y4[0] * (float4)(fp16x4_0.s0, fp16x4_0.s2, fp16x4_1.s0, fp16x4_1.s2);
acc1 += y4[4] * (float4)(fp16x4_0.s1, fp16x4_0.s3, fp16x4_1.s1, fp16x4_1.s3);
fp16x4_0 = mxfp4_to_fp16_packed(xb_q.s2);
fp16x4_1 = mxfp4_to_fp16_packed(xb_q.s3);
acc1 += y4[1] * (float4)(fp16x4_0.s0, fp16x4_0.s2, fp16x4_1.s0, fp16x4_1.s2);
acc1 += y4[5] * (float4)(fp16x4_0.s1, fp16x4_0.s3, fp16x4_1.s1, fp16x4_1.s3);
sumf[row] += e8m0_to_fp32(xb_e) * ((acc1.s0 + acc1.s1) + (acc1.s2 + acc1.s3));
}
yb += (N_SIMDWIDTH / 2) * QK_MXFP4;
}
global float * dst_f32 = (global float *)dst + (ulong)r1 * ne0;
for (int row = 0; row < N_R0_MXFP4 && first_row + row < ne0; ++row) {
float sum_all = sub_group_reduce_add(sumf[row]);
if (get_sub_group_local_id() == 0) {
dst_f32[first_row + row] = sum_all;
}
}
}

View File

@@ -0,0 +1,167 @@
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
#ifdef cl_intel_subgroups
#pragma OPENCL EXTENSION cl_intel_subgroups : enable
#else
#pragma OPENCL EXTENSION cl_khr_subgroups : enable
#endif
#ifdef cl_intel_required_subgroup_size
#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable
#define INTEL_GPU 1
#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16)))
#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32)))
#elif defined(cl_qcom_reqd_sub_group_size)
#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
#define ADRENO_GPU 1
#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half")))
#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full")))
#endif
#define QK_MXFP4 32
static inline half4 mxfp4_to_fp16_packed(ushort fp4x4) {
ushort2 fp16_packed_a, fp16_packed_b, bias_a, bias_b, sign_a, sign_b;
fp16_packed_a.lo = (fp4x4 << 9) & 0x0E00;
fp16_packed_a.hi = (fp4x4 << 5) & 0x0E00;
fp16_packed_b.lo = (fp4x4 << 1) & 0x0E00;
fp16_packed_b.hi = (fp4x4 >> 3) & 0x0E00;
bias_a.lo = (fp16_packed_a.lo == 0) ? 0x0 : 0x3800;
bias_a.hi = (fp16_packed_a.hi == 0) ? 0x0 : 0x3800;
bias_b.lo = (fp16_packed_b.lo == 0) ? 0x0 : 0x3800;
bias_b.hi = (fp16_packed_b.hi == 0) ? 0x0 : 0x3800;
fp16_packed_a.lo = (fp16_packed_a.lo == 0x0200) ? 0x0 : fp16_packed_a.lo;
fp16_packed_a.hi = (fp16_packed_a.hi == 0x0200) ? 0x0 : fp16_packed_a.hi;
fp16_packed_b.lo = (fp16_packed_b.lo == 0x0200) ? 0x0 : fp16_packed_b.lo;
fp16_packed_b.hi = (fp16_packed_b.hi == 0x0200) ? 0x0 : fp16_packed_b.hi;
sign_a.lo = (fp4x4 << 12) & 0x8000;
sign_a.hi = (fp4x4 << 8) & 0x8000;
sign_b.lo = (fp4x4 << 4) & 0x8000;
sign_b.hi = fp4x4 & 0x8000;
fp16_packed_a = sign_a + bias_a + fp16_packed_a;
fp16_packed_b = sign_b + bias_b + fp16_packed_b;
return as_half4((ushort4)(fp16_packed_a, fp16_packed_b));
}
static inline float e8m0_to_fp32(uchar x) {
int bits;
bits = (x == 0) ? 0x00400000 : ((uint) x << 23);
return as_float(bits);
}
#ifdef INTEL_GPU
#define N_R0_MXFP4 2 // number of rows each subgroup works on
#define N_SG_MXFP4 2 // number of subgroups in a work group
#define N_SIMDWIDTH 16 // subgroup size
#elif defined (ADRENO_GPU)
#define N_R0_MXFP4 2
#define N_SG_MXFP4 2
#define N_SIMDWIDTH 64
#define SRC0Q_IMG
#endif
#ifdef INTEL_GPU
REQD_SUBGROUP_SIZE_16
#elif defined (ADRENO_GPU)
REQD_SUBGROUP_SIZE_64
#endif
kernel void kernel_mul_mv_mxfp4_f32_flat(
#ifdef SRC0Q_IMG
__read_only image1d_buffer_t src0_q,
#else
global uchar * src0_q,
#endif
global uchar * src0_e,
global uchar * src1,
ulong offset1,
global uchar * dst,
ulong offsetd,
int ne00,
ulong nb01,
ulong nb02,
ulong nb03,
int ne12,
ulong nb11,
ulong nb12,
ulong nb13,
int ne0,
int ne1,
int r2,
int r3
) {
src1 = src1 + offset1;
dst = dst + offsetd;
int nb = ne00 / QK_MXFP4;
int r0 = get_group_id(0);
int r1 = get_group_id(1);
int im = get_group_id(2);
int first_row = (r0 * N_SG_MXFP4 + get_sub_group_id()) * N_R0_MXFP4;
uint i12 = im % ne12;
uint i13 = im / ne12;
uint offset_src0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
// 17 = sizeof(block_mxfp4)
offset_src0 /= 17;
#ifdef SRC0Q_IMG
ulong offset_q = offset_src0;
#else
global uchar16 * x_q = (global uchar16 *)(src0_q) + offset_src0;
#endif
global uchar * x_e = src0_e + offset_src0;
ulong offset_src1 = r1 * nb11 + i12 * nb12 + i13 * nb13;
global float * y = (global float *)(src1 + offset_src1);
const short ix = get_sub_group_local_id() >> 1; // 0...15
const short it = get_sub_group_local_id() & 1; // 0 or 1
float sumf[N_R0_MXFP4] = {0.f};
global float * yb = y + ix * QK_MXFP4 + it * 8;
for (int ib = ix; ib < nb; ib += N_SIMDWIDTH/2) {
global float4 * y4 = (global float4 *)yb;
#pragma unroll
for (short row = 0; row < N_R0_MXFP4; row++) {
uchar xb_e = x_e[row * nb + ib];
#ifdef SRC0Q_IMG
ushort4 xb_q = as_ushort4(read_imageui(src0_q, (offset_q + row * nb + ib) * 2 + it).xy);
#else
ushort4 xb_q = vload4(0, (global ushort *)((global uchar *)(x_q + row * nb + ib) + 8 * it));
#endif
half4 fp16x4_0 = mxfp4_to_fp16_packed(xb_q.s0);
half4 fp16x4_1 = mxfp4_to_fp16_packed(xb_q.s1);
float4 acc1 = y4[0] * (float4)(fp16x4_0.s0, fp16x4_0.s2, fp16x4_1.s0, fp16x4_1.s2);
acc1 += y4[4] * (float4)(fp16x4_0.s1, fp16x4_0.s3, fp16x4_1.s1, fp16x4_1.s3);
fp16x4_0 = mxfp4_to_fp16_packed(xb_q.s2);
fp16x4_1 = mxfp4_to_fp16_packed(xb_q.s3);
acc1 += y4[1] * (float4)(fp16x4_0.s0, fp16x4_0.s2, fp16x4_1.s0, fp16x4_1.s2);
acc1 += y4[5] * (float4)(fp16x4_0.s1, fp16x4_0.s3, fp16x4_1.s1, fp16x4_1.s3);
sumf[row] += e8m0_to_fp32(xb_e) * ((acc1.s0 + acc1.s1) + (acc1.s2 + acc1.s3));
}
yb += (N_SIMDWIDTH/2) * QK_MXFP4;
}
global float * dst_f32 = (global float *) dst + (ulong)im*ne0*ne1 + (ulong)r1*ne0;
for (int row = 0; row < N_R0_MXFP4 && first_row + row < ne0; ++row) {
float sum_all = sub_group_reduce_add(sumf[row]);
if (get_sub_group_local_id() == 0) {
dst_f32[first_row + row] = sum_all;
}
}
}

View File

@@ -795,7 +795,7 @@ static ggml_backend_i ggml_backend_rpc_interface = {
/* .graph_compute = */ ggml_backend_rpc_graph_compute,
/* .event_record = */ NULL,
/* .event_wait = */ NULL,
/* .optimize_graph = */ NULL,
/* .graph_optimize = */ NULL,
};
ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type(const char * endpoint) {

View File

@@ -4073,7 +4073,7 @@ static ggml_backend_i ggml_backend_sycl_interface = {
/* .graph_compute = */ ggml_backend_sycl_graph_compute,
/* .event_record = */ ggml_backend_sycl_event_record,
/* .event_wait = */ ggml_backend_sycl_event_wait,
/* .optimize_graph = */ NULL,
/* .graph_optimize = */ NULL,
};
static ggml_guid_t ggml_backend_sycl_guid() {

View File

@@ -593,7 +593,7 @@ struct vk_device_struct {
bool disable_fusion;
bool disable_host_visible_vidmem;
bool allow_sysmem_fallback;
bool disable_optimize_graph;
bool disable_graph_optimize;
#ifdef GGML_VULKAN_MEMORY_DEBUG
std::unique_ptr<vk_memory_logger> memory_logger;
@@ -3624,8 +3624,8 @@ static vk_device ggml_vk_get_device(size_t idx) {
const char* GGML_VK_ALLOW_SYSMEM_FALLBACK = getenv("GGML_VK_ALLOW_SYSMEM_FALLBACK");
device->allow_sysmem_fallback = GGML_VK_ALLOW_SYSMEM_FALLBACK != nullptr;
const char* GGML_VK_DISABLE_OPTIMIZE_GRAPH = getenv("GGML_VK_DISABLE_OPTIMIZE_GRAPH");
device->disable_optimize_graph = GGML_VK_DISABLE_OPTIMIZE_GRAPH != nullptr;
const char* GGML_VK_DISABLE_GRAPH_OPTIMIZE = getenv("GGML_VK_DISABLE_GRAPH_OPTIMIZE");
device->disable_graph_optimize = GGML_VK_DISABLE_GRAPH_OPTIMIZE != nullptr;
bool fp16_storage = false;
bool fp16_compute = false;
@@ -11914,12 +11914,12 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
}
// Sort the graph for improved parallelism.
static void ggml_vk_optimize_graph(ggml_backend_t backend, struct ggml_cgraph * graph)
static void ggml_vk_graph_optimize(ggml_backend_t backend, struct ggml_cgraph * graph)
{
VK_LOG_DEBUG("ggml_vk_optimize_graph(" << graph->n_nodes << " nodes)");
VK_LOG_DEBUG("ggml_vk_graph_optimize(" << graph->n_nodes << " nodes)");
ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context;
if (ctx->device->disable_optimize_graph) {
if (ctx->device->disable_graph_optimize) {
return;
}
@@ -12053,7 +12053,7 @@ static ggml_backend_i ggml_backend_vk_interface = {
/* .graph_compute = */ ggml_backend_vk_graph_compute,
/* .event_record = */ NULL,
/* .event_wait = */ NULL,
/* .optimize_graph = */ ggml_vk_optimize_graph,
/* .graph_optimize = */ ggml_vk_graph_optimize,
};
static ggml_guid_t ggml_backend_vk_guid() {

View File

@@ -823,7 +823,7 @@ static ggml_backend_i ggml_backend_webgpu_i = {
/* .graph_compute = */ ggml_backend_webgpu_graph_compute,
/* .event_record = */ NULL,
/* .event_wait = */ NULL,
/* .optimize_graph = */ NULL,
/* .graph_optimize = */ NULL,
};
/* End GGML Backend Interface */

View File

@@ -574,7 +574,7 @@ static ggml_backend_i ggml_backend_zdnn_i = {
/* .graph_compute = */ ggml_backend_zdnn_graph_compute,
/* .event_record = */ NULL,
/* .event_wait = */ NULL,
/* .optimize_graph = */ NULL,
/* .graph_optimize = */ NULL,
};
static ggml_guid_t ggml_backend_zdnn_guid(void) {

View File

@@ -6507,6 +6507,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
test_cases.emplace_back(new test_pad());
test_cases.emplace_back(new test_pad_ext());
test_cases.emplace_back(new test_pad_reflect_1d());
test_cases.emplace_back(new test_pad_reflect_1d(GGML_TYPE_F32, {3000, 384, 4, 1}));
test_cases.emplace_back(new test_roll());
test_cases.emplace_back(new test_arange());
test_cases.emplace_back(new test_timestep_embedding());
@@ -6645,6 +6646,12 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {1024, 10, 1, 1}));
test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {32000, 512, 1, 1}));
test_cases.emplace_back(new test_pad_reflect_1d(GGML_TYPE_F32, {512, 34, 2, 1}));
test_cases.emplace_back(new test_pad_reflect_1d(GGML_TYPE_F32, {3000, 80, 1, 1}));
test_cases.emplace_back(new test_pad_reflect_1d(GGML_TYPE_F32, {3000, 80, 4, 1}));
test_cases.emplace_back(new test_pad_reflect_1d(GGML_TYPE_F32, {3000, 384, 1, 1}));
test_cases.emplace_back(new test_pad_reflect_1d(GGML_TYPE_F32, {3000, 384, 4, 1}));
test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 16416, 1, 128, {8, 1}, {4, 1}, {0, 2, 1, 3}));
test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 128, 1, 16416, {8, 1}, {4, 1}, {0, 1, 2, 3}, true));