mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2026-05-07 01:24:24 +00:00
Compare commits
9 Commits
xsn/qwen3n
...
b6517
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
69ffd89163 | ||
|
|
246c0d9c79 | ||
|
|
3edd87cd05 | ||
|
|
c0b45097c3 | ||
|
|
38dbdf4c05 | ||
|
|
368560a1e3 | ||
|
|
4ca088b036 | ||
|
|
703f9e32c4 | ||
|
|
ad6bd9083b |
484
common/arg.cpp
484
common/arg.cpp
@@ -57,12 +57,32 @@ static std::string read_file(const std::string & fname) {
|
||||
}
|
||||
|
||||
static void write_file(const std::string & fname, const std::string & content) {
|
||||
std::ofstream file(fname);
|
||||
const std::string fname_tmp = fname + ".tmp";
|
||||
std::ofstream file(fname_tmp);
|
||||
if (!file) {
|
||||
throw std::runtime_error(string_format("error: failed to open file '%s'\n", fname.c_str()));
|
||||
}
|
||||
file << content;
|
||||
file.close();
|
||||
|
||||
try {
|
||||
file << content;
|
||||
file.close();
|
||||
|
||||
// Makes write atomic
|
||||
if (rename(fname_tmp.c_str(), fname.c_str()) != 0) {
|
||||
LOG_ERR("%s: unable to rename file: %s to %s\n", __func__, fname_tmp.c_str(), fname.c_str());
|
||||
// If rename fails, try to delete the temporary file
|
||||
if (remove(fname_tmp.c_str()) != 0) {
|
||||
LOG_ERR("%s: unable to delete temporary file: %s\n", __func__, fname_tmp.c_str());
|
||||
}
|
||||
}
|
||||
} catch (...) {
|
||||
// If anything fails, try to delete the temporary file
|
||||
if (remove(fname_tmp.c_str()) != 0) {
|
||||
LOG_ERR("%s: unable to delete temporary file: %s\n", __func__, fname_tmp.c_str());
|
||||
}
|
||||
|
||||
throw std::runtime_error(string_format("error: failed to write file '%s'\n", fname.c_str()));
|
||||
}
|
||||
}
|
||||
|
||||
common_arg & common_arg::set_examples(std::initializer_list<enum llama_example> examples) {
|
||||
@@ -217,250 +237,294 @@ struct curl_slist_ptr {
|
||||
}
|
||||
};
|
||||
|
||||
#define CURL_MAX_RETRY 3
|
||||
#define CURL_RETRY_DELAY_SECONDS 2
|
||||
|
||||
static bool curl_perform_with_retry(const std::string & url, CURL * curl, int max_attempts, int retry_delay_seconds, const char * method_name) {
|
||||
int remaining_attempts = max_attempts;
|
||||
|
||||
while (remaining_attempts > 0) {
|
||||
LOG_INF("%s: %s %s (attempt %d of %d)...\n", __func__ , method_name, url.c_str(), max_attempts - remaining_attempts + 1, max_attempts);
|
||||
|
||||
CURLcode res = curl_easy_perform(curl);
|
||||
if (res == CURLE_OK) {
|
||||
return true;
|
||||
}
|
||||
|
||||
int exponential_backoff_delay = std::pow(retry_delay_seconds, max_attempts - remaining_attempts) * 1000;
|
||||
LOG_WRN("%s: curl_easy_perform() failed: %s, retrying after %d milliseconds...\n", __func__, curl_easy_strerror(res), exponential_backoff_delay);
|
||||
|
||||
remaining_attempts--;
|
||||
if (remaining_attempts == 0) break;
|
||||
std::this_thread::sleep_for(std::chrono::milliseconds(exponential_backoff_delay));
|
||||
static CURLcode common_curl_perf(CURL * curl) {
|
||||
CURLcode res = curl_easy_perform(curl);
|
||||
if (res != CURLE_OK) {
|
||||
LOG_ERR("%s: curl_easy_perform() failed\n", __func__);
|
||||
}
|
||||
|
||||
LOG_ERR("%s: curl_easy_perform() failed after %d attempts\n", __func__, max_attempts);
|
||||
|
||||
return false;
|
||||
return res;
|
||||
}
|
||||
|
||||
// download one single file from remote URL to local path
|
||||
static bool common_download_file_single(const std::string & url, const std::string & path, const std::string & bearer_token, bool offline) {
|
||||
// Check if the file already exists locally
|
||||
auto file_exists = std::filesystem::exists(path);
|
||||
|
||||
// If the file exists, check its JSON metadata companion file.
|
||||
std::string metadata_path = path + ".json";
|
||||
nlohmann::json metadata; // TODO @ngxson : get rid of this json, use regex instead
|
||||
// Response header values collected while talking to the server
// (a HEAD request is sent to retrieve the etag and last-modified headers).
struct common_load_model_from_url_headers {
    std::string etag;           // value of the "ETag" response header
    std::string last_modified;  // value of the "Last-Modified" response header
    std::string accept_ranges;  // value of the "Accept-Ranges" response header
};
|
||||
|
||||
if (file_exists) {
|
||||
if (offline) {
|
||||
LOG_INF("%s: using cached file (offline mode): %s\n", __func__, path.c_str());
|
||||
return true; // skip verification/downloading
|
||||
// Deleter functor for std::unique_ptr<FILE, FILE_deleter>: closes the stream with fclose().
struct FILE_deleter {
    void operator()(FILE * f) const {
        fclose(f);
    }
};
|
||||
|
||||
static size_t common_header_callback(char * buffer, size_t, size_t n_items, void * userdata) {
|
||||
common_load_model_from_url_headers * headers = (common_load_model_from_url_headers *) userdata;
|
||||
static std::regex header_regex("([^:]+): (.*)\r\n");
|
||||
static std::regex etag_regex("ETag", std::regex_constants::icase);
|
||||
static std::regex last_modified_regex("Last-Modified", std::regex_constants::icase);
|
||||
static std::regex accept_ranges_regex("Accept-Ranges", std::regex_constants::icase);
|
||||
std::string header(buffer, n_items);
|
||||
std::smatch match;
|
||||
if (std::regex_match(header, match, header_regex)) {
|
||||
const std::string & key = match[1];
|
||||
const std::string & value = match[2];
|
||||
if (std::regex_match(key, match, etag_regex)) {
|
||||
headers->etag = value;
|
||||
} else if (std::regex_match(key, match, last_modified_regex)) {
|
||||
headers->last_modified = value;
|
||||
} else if (std::regex_match(key, match, accept_ranges_regex)) {
|
||||
headers->accept_ranges = value;
|
||||
}
|
||||
// Try and read the JSON metadata file (note: stream autoclosed upon exiting this block).
|
||||
std::ifstream metadata_in(metadata_path);
|
||||
if (metadata_in.good()) {
|
||||
try {
|
||||
metadata_in >> metadata;
|
||||
LOG_DBG("%s: previous metadata file found %s: %s\n", __func__, metadata_path.c_str(), metadata.dump().c_str());
|
||||
if (metadata.contains("etag") && metadata.at("etag").is_string()) {
|
||||
etag = metadata.at("etag");
|
||||
}
|
||||
if (metadata.contains("lastModified") && metadata.at("lastModified").is_string()) {
|
||||
last_modified = metadata.at("lastModified");
|
||||
}
|
||||
} catch (const nlohmann::json::exception & e) {
|
||||
LOG_ERR("%s: error reading metadata file %s: %s\n", __func__, metadata_path.c_str(), e.what());
|
||||
}
|
||||
}
|
||||
// if we cannot open the metadata file, we assume that the downloaded file is not valid (etag and last-modified are left empty, so we will download it again)
|
||||
} else {
|
||||
if (offline) {
|
||||
LOG_ERR("%s: required file is not available in cache (offline mode): %s\n", __func__, path.c_str());
|
||||
return false;
|
||||
}
|
||||
LOG_INF("%s: no previous model file found %s\n", __func__, path.c_str());
|
||||
}
|
||||
|
||||
// Send a HEAD request to retrieve the etag and last-modified headers
|
||||
struct common_load_model_from_url_headers {
|
||||
std::string etag;
|
||||
std::string last_modified;
|
||||
};
|
||||
return n_items;
|
||||
}
|
||||
|
||||
common_load_model_from_url_headers headers;
|
||||
bool head_request_ok = false;
|
||||
bool should_download = !file_exists; // by default, we should download if the file does not exist
|
||||
// Write callback with the signature libcurl expects for CURLOPT_WRITEFUNCTION:
// forwards the received chunk to the FILE * passed as `fd` and returns the
// item count reported by std::fwrite.
static size_t common_write_callback(void * data, size_t size, size_t nmemb, void * fd) {
    FILE * out = static_cast<FILE *>(fd);
    return std::fwrite(data, size, nmemb, out);
}
|
||||
|
||||
// Initialize libcurl
|
||||
curl_ptr curl(curl_easy_init(), &curl_easy_cleanup);
|
||||
curl_slist_ptr http_headers;
|
||||
// Helper function to hide credentials embedded in a URL before it is logged.
//
// Returns `url` with the userinfo part (`user[:password]@`) of the authority replaced by
// "********@". URLs that have no `scheme://` prefix or no userinfo are returned unchanged.
static std::string llama_download_hide_password_in_url(const std::string & url) {
    // Pattern: scheme://[user[:password]@]host[...]
    //   match[1] = scheme including "://" (e.g. "https://")
    //   match[2] = userinfo including the trailing '@' (optional)
    //   match[3] = rest of the URL (host, path, query, ...)
    // The userinfo is restricted to the authority (it may not contain '/', '?' or '#') and
    // is matched greedily, so an '@' inside the password is masked as well.
    static const std::regex url_regex(R"(^([A-Za-z][A-Za-z0-9+.-]*://)([^/?#]*@)?(.*)$)");

    std::smatch match;
    if (std::regex_match(url, match, url_regex) && match[2].matched) {
        return match[1].str() + "********@" + match[3].str();
    }

    return url;  // No credentials found or malformed URL
}
|
||||
|
||||
// Configures `curl` for a body-less (HEAD) request to `url`: redirects are followed,
// the progress meter is disabled and response headers are routed to common_header_callback.
static void common_curl_easy_setopt_head(CURL * curl, const std::string & url) {
    const char * target = url.c_str();

    // Target URL; HTTP redirections are followed
    curl_easy_setopt(curl, CURLOPT_URL, target);
    curl_easy_setopt(curl, CURLOPT_FOLLOWLOCATION, 1L);

#if defined(_WIN32)
    // CURLSSLOPT_NATIVE_CA tells libcurl to use standard certificate store of
    // operating system. Currently implemented under MS-Windows.
    curl_easy_setopt(curl, CURLOPT_SSL_OPTIONS, CURLSSLOPT_NATIVE_CA);
#endif

    // No body requested: this triggers the HEAD verb
    curl_easy_setopt(curl, CURLOPT_NOBODY, 1L);
    // Keep the HEAD request quiet (no progress output)
    curl_easy_setopt(curl, CURLOPT_NOPROGRESS, 1L);
    curl_easy_setopt(curl, CURLOPT_HEADERFUNCTION, common_header_callback);
}
|
||||
|
||||
// Switches `curl` from the HEAD configuration to a regular download: the body is
// requested again, received data goes through common_write_callback and the
// progress meter is enabled.
static void common_curl_easy_setopt_get(CURL * curl) {
    // Request the body again (undoes CURLOPT_NOBODY = 1)
    curl_easy_setopt(curl, CURLOPT_NOBODY, 0L);

    // Received data is handed to the file-writing callback
    curl_easy_setopt(curl, CURLOPT_WRITEFUNCTION, common_write_callback);

    // display download progress
    curl_easy_setopt(curl, CURLOPT_NOPROGRESS, 0L);
}
|
||||
|
||||
// Downloads the resource configured on `curl` into `path_temporary`.
//
// If `path_temporary` already exists, its current size is used as the start of an HTTP
// range request and the new data is appended to it. Returns true when the transfer
// finished with CURLE_OK, false if the file could not be opened or the transfer failed.
static bool common_pull_file(CURL * curl, const std::string & path_temporary) {
    const bool has_partial_file = std::filesystem::exists(path_temporary);
    if (has_partial_file) {
        // NOTE(review): the message states that the server supports range requests, but
        // nothing in this function checks that - presumably the caller does; confirm.
        const std::string resume_offset = std::to_string(std::filesystem::file_size(path_temporary));
        LOG_INF("%s: server supports range requests, resuming download from byte %s\n", __func__, resume_offset.c_str());

        // "<offset>-" requests everything from the offset to the end of the resource
        const std::string range_str = resume_offset + "-";
        curl_easy_setopt(curl, CURLOPT_RANGE, range_str.c_str());
    }

    // Append mode is used unconditionally because the download may be a resumed one
    FILE * raw_file = fopen(path_temporary.c_str(), "ab");
    std::unique_ptr<FILE, FILE_deleter> outfile(raw_file);
    if (outfile == nullptr) {
        LOG_ERR("%s: error opening local file for writing: %s\n", __func__, path_temporary.c_str());
        return false;
    }

    common_curl_easy_setopt_get(curl);
    curl_easy_setopt(curl, CURLOPT_WRITEDATA, outfile.get());

    const CURLcode result = common_curl_perf(curl);
    return result == CURLE_OK;
}
|
||||
|
||||
static bool common_download_head(CURL * curl,
|
||||
curl_slist_ptr & http_headers,
|
||||
const std::string & url,
|
||||
const std::string & bearer_token) {
|
||||
if (!curl) {
|
||||
LOG_ERR("%s: error initializing libcurl\n", __func__);
|
||||
return false;
|
||||
}
|
||||
|
||||
// Set the URL, allow to follow http redirection
|
||||
curl_easy_setopt(curl.get(), CURLOPT_URL, url.c_str());
|
||||
curl_easy_setopt(curl.get(), CURLOPT_FOLLOWLOCATION, 1L);
|
||||
|
||||
http_headers.ptr = curl_slist_append(http_headers.ptr, "User-Agent: llama-cpp");
|
||||
// Check if hf-token or bearer-token was specified
|
||||
if (!bearer_token.empty()) {
|
||||
std::string auth_header = "Authorization: Bearer " + bearer_token;
|
||||
http_headers.ptr = curl_slist_append(http_headers.ptr, auth_header.c_str());
|
||||
}
|
||||
curl_easy_setopt(curl.get(), CURLOPT_HTTPHEADER, http_headers.ptr);
|
||||
|
||||
#if defined(_WIN32)
|
||||
// CURLSSLOPT_NATIVE_CA tells libcurl to use standard certificate store of
|
||||
// operating system. Currently implemented under MS-Windows.
|
||||
curl_easy_setopt(curl.get(), CURLOPT_SSL_OPTIONS, CURLSSLOPT_NATIVE_CA);
|
||||
#endif
|
||||
|
||||
typedef size_t(*CURLOPT_HEADERFUNCTION_PTR)(char *, size_t, size_t, void *);
|
||||
auto header_callback = [](char * buffer, size_t /*size*/, size_t n_items, void * userdata) -> size_t {
|
||||
common_load_model_from_url_headers * headers = (common_load_model_from_url_headers *) userdata;
|
||||
|
||||
static std::regex header_regex("([^:]+): (.*)\r\n");
|
||||
static std::regex etag_regex("ETag", std::regex_constants::icase);
|
||||
static std::regex last_modified_regex("Last-Modified", std::regex_constants::icase);
|
||||
|
||||
std::string header(buffer, n_items);
|
||||
std::smatch match;
|
||||
if (std::regex_match(header, match, header_regex)) {
|
||||
const std::string & key = match[1];
|
||||
const std::string & value = match[2];
|
||||
if (std::regex_match(key, match, etag_regex)) {
|
||||
headers->etag = value;
|
||||
} else if (std::regex_match(key, match, last_modified_regex)) {
|
||||
headers->last_modified = value;
|
||||
}
|
||||
}
|
||||
return n_items;
|
||||
};
|
||||
|
||||
curl_easy_setopt(curl.get(), CURLOPT_NOBODY, 1L); // will trigger the HEAD verb
|
||||
curl_easy_setopt(curl.get(), CURLOPT_NOPROGRESS, 1L); // hide head request progress
|
||||
curl_easy_setopt(curl.get(), CURLOPT_HEADERFUNCTION, static_cast<CURLOPT_HEADERFUNCTION_PTR>(header_callback));
|
||||
curl_easy_setopt(curl.get(), CURLOPT_HEADERDATA, &headers);
|
||||
|
||||
// we only allow retrying once for HEAD requests
|
||||
// this is for the use case of using running offline (no internet), retrying can be annoying
|
||||
bool was_perform_successful = curl_perform_with_retry(url, curl.get(), 1, 0, "HEAD");
|
||||
if (!was_perform_successful) {
|
||||
head_request_ok = false;
|
||||
http_headers.ptr = curl_slist_append(http_headers.ptr, auth_header.c_str());
|
||||
}
|
||||
|
||||
long http_code = 0;
|
||||
curl_easy_getinfo(curl.get(), CURLINFO_RESPONSE_CODE, &http_code);
|
||||
if (http_code == 200) {
|
||||
head_request_ok = true;
|
||||
} else {
|
||||
LOG_WRN("%s: HEAD invalid http status code received: %ld\n", __func__, http_code);
|
||||
head_request_ok = false;
|
||||
}
|
||||
curl_easy_setopt(curl, CURLOPT_HTTPHEADER, http_headers.ptr);
|
||||
common_curl_easy_setopt_head(curl, url);
|
||||
return common_curl_perf(curl) == CURLE_OK;
|
||||
}
|
||||
|
||||
// if head_request_ok is false, we don't have the etag or last-modified headers
|
||||
// we leave should_download as-is, which is true if the file does not exist
|
||||
if (head_request_ok) {
|
||||
// check if ETag or Last-Modified headers are different
|
||||
// if it is, we need to download the file again
|
||||
if (!etag.empty() && etag != headers.etag) {
|
||||
LOG_WRN("%s: ETag header is different (%s != %s): triggering a new download\n", __func__, etag.c_str(), headers.etag.c_str());
|
||||
should_download = true;
|
||||
} else if (!last_modified.empty() && last_modified != headers.last_modified) {
|
||||
LOG_WRN("%s: Last-Modified header is different (%s != %s): triggering a new download\n", __func__, last_modified.c_str(), headers.last_modified.c_str());
|
||||
should_download = true;
|
||||
}
|
||||
}
|
||||
// download one single file from remote URL to local path
|
||||
static bool common_download_file_single(const std::string & url,
|
||||
const std::string & path,
|
||||
const std::string & bearer_token,
|
||||
bool offline) {
|
||||
// If the file exists, check its JSON metadata companion file.
|
||||
std::string metadata_path = path + ".json";
|
||||
static const int max_attempts = 3;
|
||||
static const int retry_delay_seconds = 2;
|
||||
for (int i = 0; i < max_attempts; ++i) {
|
||||
nlohmann::json metadata; // TODO @ngxson : get rid of this json, use regex instead
|
||||
std::string etag;
|
||||
std::string last_modified;
|
||||
|
||||
if (should_download) {
|
||||
std::string path_temporary = path + ".downloadInProgress";
|
||||
// Check if the file already exists locally
|
||||
const auto file_exists = std::filesystem::exists(path);
|
||||
if (file_exists) {
|
||||
LOG_WRN("%s: deleting previous downloaded file: %s\n", __func__, path.c_str());
|
||||
if (remove(path.c_str()) != 0) {
|
||||
LOG_ERR("%s: unable to delete file: %s\n", __func__, path.c_str());
|
||||
if (offline) {
|
||||
LOG_INF("%s: using cached file (offline mode): %s\n", __func__, path.c_str());
|
||||
return true; // skip verification/downloading
|
||||
}
|
||||
// Try and read the JSON metadata file (note: stream autoclosed upon exiting this block).
|
||||
std::ifstream metadata_in(metadata_path);
|
||||
if (metadata_in.good()) {
|
||||
try {
|
||||
metadata_in >> metadata;
|
||||
LOG_DBG("%s: previous metadata file found %s: %s\n", __func__, metadata_path.c_str(),
|
||||
metadata.dump().c_str());
|
||||
if (metadata.contains("etag") && metadata.at("etag").is_string()) {
|
||||
etag = metadata.at("etag");
|
||||
}
|
||||
if (metadata.contains("lastModified") && metadata.at("lastModified").is_string()) {
|
||||
last_modified = metadata.at("lastModified");
|
||||
}
|
||||
} catch (const nlohmann::json::exception & e) {
|
||||
LOG_ERR("%s: error reading metadata file %s: %s\n", __func__, metadata_path.c_str(), e.what());
|
||||
}
|
||||
}
|
||||
// if we cannot open the metadata file, we assume that the downloaded file is not valid (etag and last-modified are left empty, so we will download it again)
|
||||
} else {
|
||||
if (offline) {
|
||||
LOG_ERR("%s: required file is not available in cache (offline mode): %s\n", __func__, path.c_str());
|
||||
return false;
|
||||
}
|
||||
LOG_INF("%s: no previous model file found %s\n", __func__, path.c_str());
|
||||
}
|
||||
|
||||
// Set the output file
|
||||
bool head_request_ok = false;
|
||||
bool should_download = !file_exists; // by default, we should download if the file does not exist
|
||||
|
||||
struct FILE_deleter {
|
||||
void operator()(FILE * f) const {
|
||||
fclose(f);
|
||||
}
|
||||
};
|
||||
|
||||
std::unique_ptr<FILE, FILE_deleter> outfile(fopen(path_temporary.c_str(), "wb"));
|
||||
if (!outfile) {
|
||||
LOG_ERR("%s: error opening local file for writing: %s\n", __func__, path.c_str());
|
||||
return false;
|
||||
}
|
||||
|
||||
typedef size_t(*CURLOPT_WRITEFUNCTION_PTR)(void * data, size_t size, size_t nmemb, void * fd);
|
||||
auto write_callback = [](void * data, size_t size, size_t nmemb, void * fd) -> size_t {
|
||||
return fwrite(data, size, nmemb, (FILE *)fd);
|
||||
};
|
||||
curl_easy_setopt(curl.get(), CURLOPT_NOBODY, 0L);
|
||||
curl_easy_setopt(curl.get(), CURLOPT_WRITEFUNCTION, static_cast<CURLOPT_WRITEFUNCTION_PTR>(write_callback));
|
||||
curl_easy_setopt(curl.get(), CURLOPT_WRITEDATA, outfile.get());
|
||||
|
||||
// display download progress
|
||||
curl_easy_setopt(curl.get(), CURLOPT_NOPROGRESS, 0L);
|
||||
|
||||
// helper function to hide password in URL
|
||||
auto llama_download_hide_password_in_url = [](const std::string & url) -> std::string {
|
||||
std::size_t protocol_pos = url.find("://");
|
||||
if (protocol_pos == std::string::npos) {
|
||||
return url; // Malformed URL
|
||||
}
|
||||
|
||||
std::size_t at_pos = url.find('@', protocol_pos + 3);
|
||||
if (at_pos == std::string::npos) {
|
||||
return url; // No password in URL
|
||||
}
|
||||
|
||||
return url.substr(0, protocol_pos + 3) + "********" + url.substr(at_pos);
|
||||
};
|
||||
|
||||
// start the download
|
||||
LOG_INF("%s: trying to download model from %s to %s (server_etag:%s, server_last_modified:%s)...\n", __func__,
|
||||
llama_download_hide_password_in_url(url).c_str(), path.c_str(), headers.etag.c_str(), headers.last_modified.c_str());
|
||||
bool was_perform_successful = curl_perform_with_retry(url, curl.get(), CURL_MAX_RETRY, CURL_RETRY_DELAY_SECONDS, "GET");
|
||||
// Initialize libcurl
|
||||
curl_ptr curl(curl_easy_init(), &curl_easy_cleanup);
|
||||
common_load_model_from_url_headers headers;
|
||||
curl_easy_setopt(curl.get(), CURLOPT_HEADERDATA, &headers);
|
||||
curl_slist_ptr http_headers;
|
||||
const bool was_perform_successful = common_download_head(curl.get(), http_headers, url, bearer_token);
|
||||
if (!was_perform_successful) {
|
||||
return false;
|
||||
head_request_ok = false;
|
||||
}
|
||||
|
||||
long http_code = 0;
|
||||
curl_easy_getinfo (curl.get(), CURLINFO_RESPONSE_CODE, &http_code);
|
||||
if (http_code < 200 || http_code >= 400) {
|
||||
LOG_ERR("%s: invalid http status code received: %ld\n", __func__, http_code);
|
||||
return false;
|
||||
curl_easy_getinfo(curl.get(), CURLINFO_RESPONSE_CODE, &http_code);
|
||||
if (http_code == 200) {
|
||||
head_request_ok = true;
|
||||
} else {
|
||||
LOG_WRN("%s: HEAD invalid http status code received: %ld\n", __func__, http_code);
|
||||
head_request_ok = false;
|
||||
}
|
||||
|
||||
// Causes file to be closed explicitly here before we rename it.
|
||||
outfile.reset();
|
||||
|
||||
// Write the updated JSON metadata file.
|
||||
metadata.update({
|
||||
{"url", url},
|
||||
{"etag", headers.etag},
|
||||
{"lastModified", headers.last_modified}
|
||||
});
|
||||
write_file(metadata_path, metadata.dump(4));
|
||||
LOG_DBG("%s: file metadata saved: %s\n", __func__, metadata_path.c_str());
|
||||
|
||||
if (rename(path_temporary.c_str(), path.c_str()) != 0) {
|
||||
LOG_ERR("%s: unable to rename file: %s to %s\n", __func__, path_temporary.c_str(), path.c_str());
|
||||
return false;
|
||||
// if head_request_ok is false, we don't have the etag or last-modified headers
|
||||
// we leave should_download as-is, which is true if the file does not exist
|
||||
bool should_download_from_scratch = false;
|
||||
if (head_request_ok) {
|
||||
// check if ETag or Last-Modified headers are different
|
||||
// if it is, we need to download the file again
|
||||
if (!etag.empty() && etag != headers.etag) {
|
||||
LOG_WRN("%s: ETag header is different (%s != %s): triggering a new download\n", __func__, etag.c_str(),
|
||||
headers.etag.c_str());
|
||||
should_download = true;
|
||||
should_download_from_scratch = true;
|
||||
} else if (!last_modified.empty() && last_modified != headers.last_modified) {
|
||||
LOG_WRN("%s: Last-Modified header is different (%s != %s): triggering a new download\n", __func__,
|
||||
last_modified.c_str(), headers.last_modified.c_str());
|
||||
should_download = true;
|
||||
should_download_from_scratch = true;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
LOG_INF("%s: using cached file: %s\n", __func__, path.c_str());
|
||||
|
||||
const bool accept_ranges_supported = !headers.accept_ranges.empty() && headers.accept_ranges != "none";
|
||||
if (should_download) {
|
||||
if (file_exists &&
|
||||
!accept_ranges_supported) { // Resumable downloads not supported, delete and start again.
|
||||
LOG_WRN("%s: deleting previous downloaded file: %s\n", __func__, path.c_str());
|
||||
if (remove(path.c_str()) != 0) {
|
||||
LOG_ERR("%s: unable to delete file: %s\n", __func__, path.c_str());
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
const std::string path_temporary = path + ".downloadInProgress";
|
||||
if (should_download_from_scratch) {
|
||||
if (std::filesystem::exists(path_temporary)) {
|
||||
if (remove(path_temporary.c_str()) != 0) {
|
||||
LOG_ERR("%s: unable to delete file: %s\n", __func__, path_temporary.c_str());
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
if (std::filesystem::exists(path)) {
|
||||
if (remove(path.c_str()) != 0) {
|
||||
LOG_ERR("%s: unable to delete file: %s\n", __func__, path.c_str());
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Write the updated JSON metadata file.
|
||||
metadata.update({
|
||||
{ "url", url },
|
||||
{ "etag", headers.etag },
|
||||
{ "lastModified", headers.last_modified }
|
||||
});
|
||||
write_file(metadata_path, metadata.dump(4));
|
||||
LOG_DBG("%s: file metadata saved: %s\n", __func__, metadata_path.c_str());
|
||||
|
||||
// start the download
|
||||
LOG_INF("%s: trying to download model from %s to %s (server_etag:%s, server_last_modified:%s)...\n",
|
||||
__func__, llama_download_hide_password_in_url(url).c_str(), path_temporary.c_str(),
|
||||
headers.etag.c_str(), headers.last_modified.c_str());
|
||||
const bool was_pull_successful = common_pull_file(curl.get(), path_temporary);
|
||||
if (!was_pull_successful) {
|
||||
if (i + 1 < max_attempts) {
|
||||
const int exponential_backoff_delay = std::pow(retry_delay_seconds, i) * 1000;
|
||||
LOG_WRN("%s: retrying after %d milliseconds...\n", __func__, exponential_backoff_delay);
|
||||
std::this_thread::sleep_for(std::chrono::milliseconds(exponential_backoff_delay));
|
||||
} else {
|
||||
LOG_ERR("%s: curl_easy_perform() failed after %d attempts\n", __func__, max_attempts);
|
||||
}
|
||||
|
||||
continue;
|
||||
}
|
||||
|
||||
long http_code = 0;
|
||||
curl_easy_getinfo(curl.get(), CURLINFO_RESPONSE_CODE, &http_code);
|
||||
if (http_code < 200 || http_code >= 400) {
|
||||
LOG_ERR("%s: invalid http status code received: %ld\n", __func__, http_code);
|
||||
return false;
|
||||
}
|
||||
|
||||
if (rename(path_temporary.c_str(), path.c_str()) != 0) {
|
||||
LOG_ERR("%s: unable to rename file: %s to %s\n", __func__, path_temporary.c_str(), path.c_str());
|
||||
return false;
|
||||
}
|
||||
} else {
|
||||
LOG_INF("%s: using cached file: %s\n", __func__, path.c_str());
|
||||
}
|
||||
|
||||
break;
|
||||
}
|
||||
|
||||
return true;
|
||||
@@ -770,7 +834,7 @@ static std::string common_docker_get_token(const std::string & repo) {
|
||||
}
|
||||
|
||||
static std::string common_docker_resolve_model(const std::string & docker) {
|
||||
// Parse ai/smollm2:135M-Q4_K_M
|
||||
// Parse ai/smollm2:135M-Q4_0
|
||||
size_t colon_pos = docker.find(':');
|
||||
std::string repo, tag;
|
||||
if (colon_pos != std::string::npos) {
|
||||
|
||||
@@ -114,6 +114,9 @@ message(STATUS "GGML_SYSTEM_ARCH: ${GGML_SYSTEM_ARCH}")
|
||||
|
||||
if (NOT MSVC)
|
||||
if (GGML_STATIC)
|
||||
if (UNIX AND NOT APPLE)
|
||||
set(CMAKE_FIND_LIBRARY_SUFFIXES ".a;.so")
|
||||
endif()
|
||||
add_link_options(-static)
|
||||
if (MINGW)
|
||||
add_link_options(-static-libgcc -static-libstdc++)
|
||||
|
||||
@@ -116,7 +116,7 @@ extern "C" {
|
||||
void (*event_wait) (ggml_backend_t backend, ggml_backend_event_t event);
|
||||
|
||||
// (optional) sort/optimize the nodes in the graph
|
||||
void (*optimize_graph) (ggml_backend_t backend, struct ggml_cgraph * cgraph);
|
||||
void (*graph_optimize) (ggml_backend_t backend, struct ggml_cgraph * cgraph);
|
||||
};
|
||||
|
||||
struct ggml_backend {
|
||||
|
||||
@@ -463,10 +463,10 @@ void ggml_backend_event_wait(ggml_backend_t backend, ggml_backend_event_t event)
|
||||
backend->iface.event_wait(backend, event);
|
||||
}
|
||||
|
||||
static void ggml_backend_optimize_graph(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
|
||||
static void ggml_backend_graph_optimize(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
|
||||
GGML_ASSERT(backend);
|
||||
if (backend->iface.optimize_graph != NULL) {
|
||||
backend->iface.optimize_graph(backend, cgraph);
|
||||
if (backend->iface.graph_optimize != NULL) {
|
||||
backend->iface.graph_optimize(backend, cgraph);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1307,7 +1307,7 @@ void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct ggml_cgra
|
||||
|
||||
// Optimize this split of the graph. This needs to happen before we make graph_copy,
|
||||
// so they are in sync.
|
||||
ggml_backend_optimize_graph(sched->backends[split->backend_id], &split->graph);
|
||||
ggml_backend_graph_optimize(sched->backends[split->backend_id], &split->graph);
|
||||
|
||||
// add inputs to the graph copy so that they are allocated by ggml-alloc at the start of the split
|
||||
for (int j = 0; j < split->n_inputs; j++) {
|
||||
|
||||
@@ -270,7 +270,7 @@ static struct ggml_backend_i blas_backend_i = {
|
||||
/* .graph_compute = */ ggml_backend_blas_graph_compute,
|
||||
/* .event_record = */ NULL,
|
||||
/* .event_wait = */ NULL,
|
||||
/* .optimize_graph = */ NULL,
|
||||
/* .graph_optimize = */ NULL,
|
||||
};
|
||||
|
||||
static ggml_guid_t ggml_backend_blas_guid(void) {
|
||||
|
||||
@@ -2756,7 +2756,7 @@ static const ggml_backend_i ggml_backend_cann_interface = {
|
||||
/* .graph_compute = */ ggml_backend_cann_graph_compute,
|
||||
/* .event_record = */ ggml_backend_cann_event_record,
|
||||
/* .event_wait = */ ggml_backend_cann_event_wait,
|
||||
/* .optimize_graph = */ NULL,
|
||||
/* .graph_optimize = */ NULL,
|
||||
};
|
||||
|
||||
/**
|
||||
|
||||
@@ -7,7 +7,7 @@
|
||||
#include "ggml-cpu.h"
|
||||
#include "traits.h"
|
||||
|
||||
#if defined(__gnu_linux__)
|
||||
#if defined(__linux__)
|
||||
#include <sys/syscall.h>
|
||||
#include <unistd.h>
|
||||
#endif
|
||||
@@ -186,7 +186,7 @@ static size_t ggml_backend_amx_buffer_type_get_alloc_size(ggml_backend_buffer_ty
|
||||
#define XFEATURE_XTILEDATA 18
|
||||
|
||||
static bool ggml_amx_init() {
|
||||
#if defined(__gnu_linux__)
|
||||
#if defined(__linux__)
|
||||
if (syscall(SYS_arch_prctl, ARCH_REQ_XCOMP_PERM, XFEATURE_XTILEDATA)) {
|
||||
fprintf(stderr, "AMX is not ready to be used!\n");
|
||||
return false;
|
||||
@@ -194,6 +194,8 @@ static bool ggml_amx_init() {
|
||||
return true;
|
||||
#elif defined(_WIN32)
|
||||
return true;
|
||||
#else
|
||||
return false;
|
||||
#endif
|
||||
}
|
||||
|
||||
|
||||
@@ -190,7 +190,7 @@ static const struct ggml_backend_i ggml_backend_cpu_i = {
|
||||
/* .graph_compute = */ ggml_backend_cpu_graph_compute,
|
||||
/* .event_record = */ NULL,
|
||||
/* .event_wait = */ NULL,
|
||||
/* .optimize_graph = */ NULL,
|
||||
/* .graph_optimize = */ NULL,
|
||||
};
|
||||
|
||||
static ggml_guid_t ggml_backend_cpu_guid(void) {
|
||||
|
||||
@@ -652,6 +652,14 @@ static __device__ __forceinline__ uint32_t fastmodulo(uint32_t n, const uint3 fa
|
||||
return n - fastdiv(n, fastdiv_values) * fastdiv_values.z;
|
||||
}
|
||||
|
||||
// Computes quotient and remainder in one go; the result is <n/divisor, n%divisor> in <x, y>.
// Expects fastdiv_values to contain <mp, L, divisor> in <x, y, z> (see init_fastdiv_values).
static __device__ __forceinline__ uint2 fast_div_modulo(uint32_t n, const uint3 fastdiv_values) {
    const uint32_t quotient = fastdiv(n, fastdiv_values);
    // remainder = n - quotient * divisor, with the divisor stored in .z
    const uint32_t remainder = n - quotient * fastdiv_values.z;
    return make_uint2(quotient, remainder);
}
|
||||
|
||||
typedef void (*dequantize_kernel_t)(const void * vx, const int64_t ib, const int iqs, float2 & v);
|
||||
|
||||
static __device__ __forceinline__ float get_alibi_slope(
|
||||
|
||||
@@ -441,6 +441,10 @@ void* ggml_cuda_cpy_fn(const ggml_tensor * src0, ggml_tensor * src1) {
|
||||
return (void*) cpy_flt<cpy_1_flt<nv_bfloat16, nv_bfloat16>>;
|
||||
} else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F32) {
|
||||
return (void*) cpy_flt<cpy_1_flt<nv_bfloat16, float>>;
|
||||
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_I32) {
|
||||
return (void*) cpy_flt<cpy_1_flt<float, int32_t>>;
|
||||
} else if (src0->type == GGML_TYPE_I32 && src1->type == GGML_TYPE_F32) {
|
||||
return (void*) cpy_flt<cpy_1_flt<int32_t, float>>;
|
||||
} else {
|
||||
GGML_ABORT("%s: unsupported type combination (%s to %s)\n", __func__,
|
||||
ggml_type_name(src0->type), ggml_type_name(src1->type));
|
||||
|
||||
@@ -35,7 +35,6 @@ static int fattn_tile_get_kq_stride_host(const int D, const int ncols, const int
|
||||
switch (D) {
|
||||
case 64:
|
||||
case 128:
|
||||
return 128;
|
||||
case 256:
|
||||
return ncols <= 16 ? 128 : 64;
|
||||
default:
|
||||
@@ -86,7 +85,6 @@ static constexpr __device__ int fattn_tile_get_kq_stride_device(int D, int ncols
|
||||
switch (D) {
|
||||
case 64:
|
||||
case 128:
|
||||
return 128;
|
||||
case 256:
|
||||
return ncols <= 16 ? 128 : 64;
|
||||
default:
|
||||
|
||||
@@ -3140,7 +3140,7 @@ static const ggml_backend_i ggml_backend_cuda_interface = {
|
||||
/* .graph_compute = */ ggml_backend_cuda_graph_compute,
|
||||
/* .event_record = */ ggml_backend_cuda_event_record,
|
||||
/* .event_wait = */ ggml_backend_cuda_event_wait,
|
||||
/* .optimize_graph = */ NULL,
|
||||
/* .graph_optimize = */ NULL,
|
||||
};
|
||||
|
||||
static ggml_guid_t ggml_backend_cuda_guid() {
|
||||
|
||||
@@ -1,82 +1,89 @@
|
||||
#include "pad_reflect_1d.cuh"
|
||||
|
||||
static __global__ void pad_reflect_1d_kernel_f32(
|
||||
const void * __restrict__ src0,
|
||||
void * __restrict__ dst,
|
||||
const int64_t ne0,
|
||||
const int64_t ne00,
|
||||
const int64_t ne01,
|
||||
const int64_t ne02,
|
||||
const int64_t ne03,
|
||||
const int64_t nb00,
|
||||
const int64_t nb01,
|
||||
const int64_t nb02,
|
||||
const int64_t nb03,
|
||||
const int64_t nb0,
|
||||
const int64_t nb1,
|
||||
const int64_t nb2,
|
||||
const int64_t nb3,
|
||||
const int p0,
|
||||
const int p1) {
|
||||
|
||||
static __global__ __launch_bounds__(CUDA_PAD_REFLECT_1D_BLOCK_SIZE, 1) void
|
||||
pad_reflect_1d_kernel_f32(
|
||||
const void * __restrict__ src0,
|
||||
void * __restrict__ dst,
|
||||
const int64_t ne0,
|
||||
const int64_t ne00,
|
||||
const uint3 ne01,
|
||||
const int64_t ne02,
|
||||
const int64_t ne03,
|
||||
const int64_t nb00,
|
||||
const int64_t nb01,
|
||||
const int64_t nb02,
|
||||
const int64_t nb03,
|
||||
const int64_t nb0,
|
||||
const int64_t nb1,
|
||||
const int64_t nb2,
|
||||
const int64_t nb3,
|
||||
const int p0,
|
||||
const int p1) {
|
||||
const int64_t i3 = blockIdx.z;
|
||||
const int64_t i2 = blockIdx.y;
|
||||
const int64_t i1 = blockIdx.x;
|
||||
|
||||
if (i1 >= ne01 || i2 >= ne02 || i3 >= ne03) {
|
||||
const uint2 div_mod_packed = fast_div_modulo(blockIdx.x, ne01);
|
||||
const int64_t tile1 = div_mod_packed.y; // i1
|
||||
const int64_t tile0 = div_mod_packed.x; // nth i0 tile
|
||||
const int64_t i1 = tile1;
|
||||
const int64_t i0 = threadIdx.x + tile0 * blockDim.x;
|
||||
|
||||
// ne01.z is original value of unpacked ne01 (see init_fastdiv_values in common.cuh)
|
||||
if (i0 >= ne0 || i1 >= ne01.z || i2 >= ne02 || i3 >= ne03) {
|
||||
return;
|
||||
}
|
||||
|
||||
const char * src0_ptr = (const char *)src0 + i3*nb03 + i2*nb02 + i1*nb01;
|
||||
char * dst_ptr = (char *)dst + i3*nb3 + i2*nb2 + i1*nb1;
|
||||
const char * src0_ptr = (const char *) src0 + i3 * nb03 + i2 * nb02 + i1 * nb01;
|
||||
char * dst_ptr = (char *) dst + i3 * nb3 + i2 * nb2 + i1 * nb1;
|
||||
|
||||
for (int64_t i0 = threadIdx.x; i0 < ne0; i0 += blockDim.x) {
|
||||
float value;
|
||||
const int64_t rel_i0 = i0 - p0; // relative i0 in src0
|
||||
int64_t src_idx;
|
||||
|
||||
if (i0 < p0) {
|
||||
// Left padding - reflect
|
||||
value = *(const float *)(src0_ptr + (p0 - i0) * nb00);
|
||||
} else if (i0 < ne0 - p1) {
|
||||
// Middle - copy
|
||||
value = *(const float *)(src0_ptr + (i0 - p0) * nb00);
|
||||
} else {
|
||||
// Right padding - reflect
|
||||
int64_t src_idx = (ne0 - p1 - p0) - (p1 + 1 - (ne0 - i0)) - 1;
|
||||
value = *(const float *)(src0_ptr + src_idx * nb00);
|
||||
}
|
||||
|
||||
*(float *)(dst_ptr + i0 * nb0) = value;
|
||||
if (rel_i0 < 0) {
|
||||
// Left padding - reflect
|
||||
src_idx = -rel_i0;
|
||||
} else if (rel_i0 < ne00) {
|
||||
// Middle - copy
|
||||
src_idx = rel_i0;
|
||||
} else {
|
||||
// Right padding - reflect
|
||||
src_idx = 2 * ne00 - 2 - rel_i0;
|
||||
}
|
||||
const float value = *(const float *) (src0_ptr + src_idx * nb00);
|
||||
*(float *) (dst_ptr + i0 * nb0) = value;
|
||||
}
|
||||
|
||||
void ggml_cuda_op_pad_reflect_1d(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
const ggml_tensor * src0 = dst->src[0];
|
||||
cudaStream_t stream = ctx.stream();
|
||||
const ggml_tensor * src0 = dst->src[0];
|
||||
cudaStream_t stream = ctx.stream();
|
||||
|
||||
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(dst->type == GGML_TYPE_F32);
|
||||
|
||||
const int32_t * opts = (const int32_t *) dst->op_params;
|
||||
const int p0 = opts[0];
|
||||
const int p1 = opts[1];
|
||||
const int p0 = opts[0];
|
||||
const int p1 = opts[1];
|
||||
|
||||
const int64_t ne00 = src0->ne[0];
|
||||
const int64_t ne01 = src0->ne[1];
|
||||
const int64_t ne02 = src0->ne[2];
|
||||
const int64_t ne03 = src0->ne[3];
|
||||
const int64_t ne00 = src0->ne[0];
|
||||
const int64_t ne01 = src0->ne[1];
|
||||
const uint3 ne01_packed = init_fastdiv_values(ne01);
|
||||
const int64_t ne02 = src0->ne[2];
|
||||
const int64_t ne03 = src0->ne[3];
|
||||
|
||||
const int64_t ne0 = dst->ne[0];
|
||||
|
||||
// sanity: padded length matches
|
||||
GGML_ASSERT(ne0 == ne00 + p0 + p1);
|
||||
|
||||
const dim3 block_dims(CUDA_PAD_REFLECT_1D_BLOCK_SIZE, 1, 1);
|
||||
const dim3 grid_dims(ne01, ne02, ne03);
|
||||
constexpr int64_t bx = CUDA_PAD_REFLECT_1D_BLOCK_SIZE; // threads per block (x)
|
||||
const int64_t tiles0 = (ne0 + bx - 1) / bx; // number of tiles along i0
|
||||
// grid.x covers i1 and all tiles of i0: [ne01 * tiles0]
|
||||
// grid.y covers i2: [ne02]
|
||||
// grid.z covers i3: [ne03]
|
||||
const dim3 grid_dims((unsigned) (ne01 * tiles0), (unsigned) ne02, (unsigned) ne03);
|
||||
const dim3 block_dims((unsigned) bx, 1, 1);
|
||||
|
||||
pad_reflect_1d_kernel_f32<<<grid_dims, block_dims, 0, stream>>>(
|
||||
src0->data, dst->data,
|
||||
ne0, ne00, ne01, ne02, ne03,
|
||||
src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3],
|
||||
dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3],
|
||||
p0, p1
|
||||
);
|
||||
src0->data, dst->data, ne0, ne00, ne01_packed, ne02, ne03, src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3],
|
||||
dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3], p0, p1);
|
||||
}
|
||||
|
||||
@@ -414,19 +414,26 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_rwkv(ggml_metal_library_t
|
||||
return res;
|
||||
}
|
||||
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv_ext(ggml_metal_library_t lib, ggml_type tsrc0, ggml_type tsrc1, int r1ptg) {
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv_ext(ggml_metal_library_t lib, ggml_type tsrc0, ggml_type tsrc1, int nsg, int nxpsg, int r1ptg) {
|
||||
char base[256];
|
||||
char name[256];
|
||||
|
||||
snprintf(base, 256, "kernel_mul_mv_ext_%s_%s_r1_%d", ggml_type_name(tsrc0), ggml_type_name(tsrc1), r1ptg);
|
||||
snprintf(name, 256, "%s", base);
|
||||
snprintf(name, 256, "%s_nsg=%d_nxpsg=%d", base, nsg, nxpsg);
|
||||
|
||||
ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
|
||||
if (res) {
|
||||
return res;
|
||||
}
|
||||
|
||||
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
|
||||
ggml_metal_cv_t cv = ggml_metal_cv_init();
|
||||
|
||||
ggml_metal_cv_set_int16(cv, nsg, FC_MUL_MV + 0);
|
||||
ggml_metal_cv_set_int16(cv, nxpsg, FC_MUL_MV + 1);
|
||||
|
||||
res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
|
||||
|
||||
ggml_metal_cv_free(cv);
|
||||
|
||||
return res;
|
||||
}
|
||||
@@ -608,7 +615,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv(ggml_metal_library_
|
||||
};
|
||||
|
||||
snprintf(base, 256, "kernel_mul_mv_%s_%s%s", ggml_type_name(tsrc0), ggml_type_name(tsrc1), suffix);
|
||||
snprintf(name, 256, "%s", base);
|
||||
snprintf(name, 256, "%s_nsg=%d", base, nsg);
|
||||
|
||||
ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
|
||||
if (res) {
|
||||
@@ -824,7 +831,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv_id(ggml_metal_libra
|
||||
};
|
||||
|
||||
snprintf(base, 256, "kernel_mul_mv_id_%s_%s%s", ggml_type_name(tsrc0), ggml_type_name(tsrc1), suffix);
|
||||
snprintf(name, 256, "%s", base);
|
||||
snprintf(name, 256, "%s_nsg=%d", base, nsg);
|
||||
|
||||
ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
|
||||
if (res) {
|
||||
@@ -923,11 +930,8 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext(
|
||||
dk,
|
||||
dv);
|
||||
|
||||
snprintf(name, 256, "kernel_%s_%s_dk%d_dv%d_mask=%d_sinks=%d_bias=%d_scap=%d_ns10=%d_ns20=%d_nsg=%d",
|
||||
"flash_attn_ext",
|
||||
ggml_type_name(op->src[1]->type),
|
||||
dk,
|
||||
dv,
|
||||
snprintf(name, 256, "%s_mask=%d_sinks=%d_bias=%d_scap=%d_ns10=%d_ns20=%d_nsg=%d",
|
||||
base,
|
||||
has_mask,
|
||||
has_sinks,
|
||||
has_bias,
|
||||
@@ -985,11 +989,8 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_vec(
|
||||
dk,
|
||||
dv);
|
||||
|
||||
snprintf(name, 256, "kernel_%s_%s_dk%d_dv%d_mask=%d_sink=%d_bias=%d_softcap=%d_ns10=%d_ns20=%d_nsg=%d_nwg=%d",
|
||||
"flash_attn_ext_vec",
|
||||
ggml_type_name(op->src[1]->type),
|
||||
dk,
|
||||
dv,
|
||||
snprintf(name, 256, "%s_mask=%d_sink=%d_bias=%d_softcap=%d_ns10=%d_ns20=%d_nsg=%d_nwg=%d",
|
||||
base,
|
||||
has_mask,
|
||||
has_sinks,
|
||||
has_bias,
|
||||
@@ -1033,7 +1034,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_vec_reduce(
|
||||
char name[256];
|
||||
|
||||
snprintf(base, 256, "kernel_flash_attn_ext_vec_reduce");
|
||||
snprintf(name, 256, "kernel_flash_attn_ext_vec_reduce_dv=%d_nwg=%d", dv, nwg);
|
||||
snprintf(name, 256, "%s_dv=%d_nwg=%d", base, dv, nwg);
|
||||
|
||||
ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
|
||||
if (res) {
|
||||
|
||||
@@ -114,7 +114,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_soft_max (ggml_me
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_ssm_conv (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_ssm_scan (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_rwkv (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv_ext (ggml_metal_library_t lib, enum ggml_type tsrc0, enum ggml_type tsrc1, int r1ptg);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv_ext (ggml_metal_library_t lib, enum ggml_type tsrc0, enum ggml_type tsrc1, int nsg, int nxpsg, int r1ptg);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mm (ggml_metal_library_t lib, enum ggml_type tsrc0, enum ggml_type tsrc1);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mm_id_map0 (ggml_metal_library_t lib, int ne02, int ne20);
|
||||
|
||||
@@ -35,13 +35,13 @@
|
||||
#define N_R0_Q3_K 2
|
||||
#define N_SG_Q3_K 2
|
||||
|
||||
#define N_R0_Q4_K 4
|
||||
#define N_R0_Q4_K 2
|
||||
#define N_SG_Q4_K 2
|
||||
|
||||
#define N_R0_Q5_K 2
|
||||
#define N_SG_Q5_K 2
|
||||
|
||||
#define N_R0_Q6_K 1
|
||||
#define N_R0_Q6_K 2
|
||||
#define N_SG_Q6_K 2
|
||||
|
||||
#define N_R0_IQ1_S 4
|
||||
@@ -374,9 +374,6 @@ typedef struct {
|
||||
int32_t ne1;
|
||||
int16_t r2;
|
||||
int16_t r3;
|
||||
int16_t nsg;
|
||||
int16_t nxpsg;
|
||||
int16_t r1ptg;
|
||||
} ggml_metal_kargs_mul_mv_ext;
|
||||
|
||||
typedef struct {
|
||||
|
||||
@@ -1444,7 +1444,7 @@ int ggml_metal_op_mul_mat(ggml_metal_op_t ctx, int idx) {
|
||||
GGML_ABORT("unsupported ne11");
|
||||
};
|
||||
|
||||
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_mul_mv_ext(lib, op->src[0]->type, op->src[1]->type, r1ptg);
|
||||
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_mul_mv_ext(lib, op->src[0]->type, op->src[1]->type, nsg, nxpsg, r1ptg);
|
||||
|
||||
ggml_metal_kargs_mul_mv_ext args = {
|
||||
/*.ne00 =*/ ne00,
|
||||
@@ -1465,9 +1465,6 @@ int ggml_metal_op_mul_mat(ggml_metal_op_t ctx, int idx) {
|
||||
/*.ne1 =*/ ne1,
|
||||
/*.r2 =*/ r2,
|
||||
/*.r3 =*/ r3,
|
||||
/*.nsg =*/ nsg,
|
||||
/*.nxpsg =*/ nxpsg,
|
||||
/*.r1ptg =*/ r1ptg,
|
||||
};
|
||||
|
||||
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
||||
|
||||
@@ -447,7 +447,7 @@ static ggml_backend_i ggml_backend_metal_i = {
|
||||
// https://developer.apple.com/documentation/metal/mtlcommandbuffer#Synchronizing-Passes-with-Events
|
||||
/* .event_record = */ NULL,
|
||||
/* .event_wait = */ NULL,
|
||||
/* .optimize_graph = */ ggml_backend_metal_graph_optimize,
|
||||
/* .graph_optimize = */ ggml_backend_metal_graph_optimize,
|
||||
};
|
||||
|
||||
static ggml_guid_t ggml_backend_metal_guid(void) {
|
||||
|
||||
@@ -2843,7 +2843,7 @@ inline float block_q_n_dot_y(device const block_q5_1 * qb_curr, float sumy, thre
|
||||
return d * (acc[0] + acc[1] + acc[2] + acc[3]) + sumy * m;
|
||||
}
|
||||
|
||||
template<short NR0, short NW>
|
||||
template<short NR0>
|
||||
static inline void helper_mv_reduce_and_write(
|
||||
device float * dst_f32,
|
||||
float sumf[NR0],
|
||||
@@ -2852,6 +2852,8 @@ static inline void helper_mv_reduce_and_write(
|
||||
ushort tiisg,
|
||||
ushort sgitg,
|
||||
threadgroup char * shmem) {
|
||||
constexpr short NW = N_SIMDWIDTH;
|
||||
|
||||
threadgroup float * shmem_f32[NR0];
|
||||
|
||||
for (short row = 0; row < NR0; ++row) {
|
||||
@@ -2883,9 +2885,10 @@ static inline void helper_mv_reduce_and_write(
|
||||
}
|
||||
}
|
||||
|
||||
constant short FC_mul_mv_nsg [[function_constant(FC_MUL_MV + 0)]];
|
||||
constant short FC_mul_mv_nsg [[function_constant(FC_MUL_MV + 0)]];
|
||||
constant short FC_mul_mv_nxpsg [[function_constant(FC_MUL_MV + 1)]];
|
||||
|
||||
template<typename block_q_type, short NR0, short NW, typename args_t>
|
||||
template<typename block_q_type, short NR0, typename args_t>
|
||||
void mul_vec_q_n_f32_impl(
|
||||
args_t args,
|
||||
device const char * src0,
|
||||
@@ -2897,6 +2900,7 @@ void mul_vec_q_n_f32_impl(
|
||||
ushort sgitg) {
|
||||
const short NSG = FC_mul_mv_nsg;
|
||||
|
||||
constexpr short NW = N_SIMDWIDTH;
|
||||
constexpr short NQ = 16;
|
||||
|
||||
const int nb = args.ne00/QK4_0;
|
||||
@@ -2961,7 +2965,7 @@ void mul_vec_q_n_f32_impl(
|
||||
|
||||
device float * dst_f32 = (device float *) dst + im*args.ne0*args.ne1 + r1*args.ne0;
|
||||
|
||||
//helper_mv_reduce_and_write<NR0, NW>(dst_f32, sumf, r0, args.ne01, tiisg, sgitg, shmem);
|
||||
//helper_mv_reduce_and_write<NR0>(dst_f32, sumf, r0, args.ne01, tiisg, sgitg, shmem);
|
||||
|
||||
for (int row = 0; row < NR0; ++row) {
|
||||
const float tot = simd_sum(sumf[row]);
|
||||
@@ -2981,7 +2985,7 @@ kernel void kernel_mul_mv_q4_0_f32(
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
ushort tiisg[[thread_index_in_simdgroup]],
|
||||
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||
mul_vec_q_n_f32_impl<block_q4_0, N_R0_Q4_0, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
|
||||
mul_vec_q_n_f32_impl<block_q4_0, N_R0_Q4_0, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
|
||||
}
|
||||
|
||||
kernel void kernel_mul_mv_q4_1_f32(
|
||||
@@ -2993,7 +2997,7 @@ kernel void kernel_mul_mv_q4_1_f32(
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
ushort tiisg[[thread_index_in_simdgroup]],
|
||||
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||
mul_vec_q_n_f32_impl<block_q4_1, N_R0_Q4_1, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
|
||||
mul_vec_q_n_f32_impl<block_q4_1, N_R0_Q4_1, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
|
||||
}
|
||||
|
||||
kernel void kernel_mul_mv_q5_0_f32(
|
||||
@@ -3005,7 +3009,7 @@ kernel void kernel_mul_mv_q5_0_f32(
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
ushort tiisg[[thread_index_in_simdgroup]],
|
||||
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||
mul_vec_q_n_f32_impl<block_q5_0, N_R0_Q5_0, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
|
||||
mul_vec_q_n_f32_impl<block_q5_0, N_R0_Q5_0, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
|
||||
}
|
||||
|
||||
kernel void kernel_mul_mv_q5_1_f32(
|
||||
@@ -3017,10 +3021,10 @@ kernel void kernel_mul_mv_q5_1_f32(
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
ushort tiisg[[thread_index_in_simdgroup]],
|
||||
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||
mul_vec_q_n_f32_impl<block_q5_1, N_R0_Q5_1, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
|
||||
mul_vec_q_n_f32_impl<block_q5_1, N_R0_Q5_1, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
|
||||
}
|
||||
|
||||
template<short NR0, short NW, typename args_t>
|
||||
template<short NR0, typename args_t>
|
||||
void kernel_mul_mv_q8_0_f32_impl(
|
||||
args_t args,
|
||||
device const char * src0,
|
||||
@@ -3032,6 +3036,7 @@ void kernel_mul_mv_q8_0_f32_impl(
|
||||
ushort sgitg) {
|
||||
const short NSG = FC_mul_mv_nsg;
|
||||
|
||||
constexpr short NW = N_SIMDWIDTH;
|
||||
constexpr short NQ = 8;
|
||||
|
||||
const int nb = args.ne00/QK8_0;
|
||||
@@ -3090,7 +3095,7 @@ void kernel_mul_mv_q8_0_f32_impl(
|
||||
|
||||
device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
|
||||
|
||||
helper_mv_reduce_and_write<NR0, NW>(dst_f32, sumf, r0, args.ne01, tiisg, sgitg, shmem);
|
||||
helper_mv_reduce_and_write<NR0>(dst_f32, sumf, r0, args.ne01, tiisg, sgitg, shmem);
|
||||
}
|
||||
|
||||
[[host_name("kernel_mul_mv_q8_0_f32")]]
|
||||
@@ -3103,12 +3108,12 @@ kernel void kernel_mul_mv_q8_0_f32(
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
ushort tiisg[[thread_index_in_simdgroup]],
|
||||
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||
kernel_mul_mv_q8_0_f32_impl<N_R0_Q8_0, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
|
||||
kernel_mul_mv_q8_0_f32_impl<N_R0_Q8_0, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
|
||||
}
|
||||
|
||||
// mat-vec kernel processing in chunks of float4
|
||||
// chpb - chunks per quantization block
|
||||
template<short nxpsg, short r1ptg, typename q_t, short chpb, void (*deq_t4)(device const q_t *, short, thread float4 &) >
|
||||
template<short r1ptg, typename q_t, short chpb, void (*deq_t4)(device const q_t *, short, thread float4 &) >
|
||||
void kernel_mul_mv_ext_q4_f32_impl(
|
||||
constant ggml_metal_kargs_mul_mv_ext & args,
|
||||
device const char * src0,
|
||||
@@ -3117,6 +3122,9 @@ void kernel_mul_mv_ext_q4_f32_impl(
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
ushort tiisg[[thread_index_in_simdgroup]],
|
||||
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||
const short NSG = FC_mul_mv_nsg;
|
||||
const short nxpsg = FC_mul_mv_nxpsg;
|
||||
|
||||
const short chpt = 4; // chunks per thread
|
||||
|
||||
//const short nxpsg = (32);
|
||||
@@ -3125,7 +3133,7 @@ void kernel_mul_mv_ext_q4_f32_impl(
|
||||
const short tx = tiisg%nxpsg;
|
||||
const short ty = tiisg/nxpsg;
|
||||
|
||||
const int i01 = tgpig.x*(nypsg*args.nsg) + nypsg*sgitg + ty;
|
||||
const int i01 = tgpig.x*(nypsg*NSG) + nypsg*sgitg + ty;
|
||||
const int i11 = tgpig.y*r1ptg;
|
||||
const int i1m = tgpig.z;
|
||||
|
||||
@@ -3208,7 +3216,7 @@ void kernel_mul_mv_ext_q4_f32_impl(
|
||||
}
|
||||
|
||||
// mat-vec kernel processing in chunks of float4x4
|
||||
template<short nxpsg, short r1ptg, typename q_t, short chpb, void (*deq_t4x4)(device const q_t *, short, thread float4x4 &) >
|
||||
template<short r1ptg, typename q_t, short chpb, void (*deq_t4x4)(device const q_t *, short, thread float4x4 &) >
|
||||
void kernel_mul_mv_ext_q4x4_f32_impl(
|
||||
constant ggml_metal_kargs_mul_mv_ext & args,
|
||||
device const char * src0,
|
||||
@@ -3217,6 +3225,9 @@ void kernel_mul_mv_ext_q4x4_f32_impl(
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
ushort tiisg[[thread_index_in_simdgroup]],
|
||||
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||
const short NSG = FC_mul_mv_nsg;
|
||||
const short nxpsg = FC_mul_mv_nxpsg;
|
||||
|
||||
const short chpt = 1;
|
||||
|
||||
//const short nxpsg = (32);
|
||||
@@ -3225,7 +3236,7 @@ void kernel_mul_mv_ext_q4x4_f32_impl(
|
||||
const short tx = tiisg%nxpsg;
|
||||
const short ty = tiisg/nxpsg;
|
||||
|
||||
const int i01 = tgpig.x*(nypsg*args.nsg) + nypsg*sgitg + ty;
|
||||
const int i01 = tgpig.x*(nypsg*NSG) + nypsg*sgitg + ty;
|
||||
const int i11 = tgpig.y*r1ptg;
|
||||
const int i1m = tgpig.z;
|
||||
|
||||
@@ -3322,12 +3333,7 @@ kernel void kernel_mul_mv_ext_q4_f32_disp(
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
ushort tiisg[[thread_index_in_simdgroup]],
|
||||
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||
switch (args.nxpsg) {
|
||||
case 4: kernel_mul_mv_ext_q4_f32_impl<4, r1ptg, q_t, epb/4, deq_t4>(args, src0, src1, dst, tgpig, tiisg, sgitg); break;
|
||||
case 8: kernel_mul_mv_ext_q4_f32_impl<8, r1ptg, q_t, epb/4, deq_t4>(args, src0, src1, dst, tgpig, tiisg, sgitg); break;
|
||||
case 16: kernel_mul_mv_ext_q4_f32_impl<16, r1ptg, q_t, epb/4, deq_t4>(args, src0, src1, dst, tgpig, tiisg, sgitg); break;
|
||||
case 32: kernel_mul_mv_ext_q4_f32_impl<32, r1ptg, q_t, epb/4, deq_t4>(args, src0, src1, dst, tgpig, tiisg, sgitg); break;
|
||||
}
|
||||
kernel_mul_mv_ext_q4_f32_impl<r1ptg, q_t, epb/4, deq_t4>(args, src0, src1, dst, tgpig, tiisg, sgitg);
|
||||
}
|
||||
|
||||
template<short r1ptg, typename q_t, short epb, void (*deq_t4x4)(device const q_t *, short, thread float4x4 &)>
|
||||
@@ -3339,12 +3345,7 @@ kernel void kernel_mul_mv_ext_q4x4_f32_disp(
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
ushort tiisg[[thread_index_in_simdgroup]],
|
||||
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||
switch (args.nxpsg) {
|
||||
case 4: kernel_mul_mv_ext_q4x4_f32_impl<4, r1ptg, q_t, epb/16, deq_t4x4>(args, src0, src1, dst, tgpig, tiisg, sgitg); break;
|
||||
case 8: kernel_mul_mv_ext_q4x4_f32_impl<8, r1ptg, q_t, epb/16, deq_t4x4>(args, src0, src1, dst, tgpig, tiisg, sgitg); break;
|
||||
case 16: kernel_mul_mv_ext_q4x4_f32_impl<16, r1ptg, q_t, epb/16, deq_t4x4>(args, src0, src1, dst, tgpig, tiisg, sgitg); break;
|
||||
case 32: kernel_mul_mv_ext_q4x4_f32_impl<32, r1ptg, q_t, epb/16, deq_t4x4>(args, src0, src1, dst, tgpig, tiisg, sgitg); break;
|
||||
}
|
||||
kernel_mul_mv_ext_q4x4_f32_impl<r1ptg, q_t, epb/16, deq_t4x4>(args, src0, src1, dst, tgpig, tiisg, sgitg);
|
||||
}
|
||||
|
||||
typedef decltype(kernel_mul_mv_ext_q4_f32_disp <2, block_q8_0, 32, dequantize_q8_0_t4>) mul_mv_ext_q4_f32_t;
|
||||
@@ -3410,7 +3411,7 @@ template [[host_name("kernel_mul_mv_ext_q6_K_f32_r1_3")]] kernel mul_mv_ext_q4x4
|
||||
template [[host_name("kernel_mul_mv_ext_q6_K_f32_r1_4")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<4, block_q6_K, 256, dequantize_q6_K>;
|
||||
template [[host_name("kernel_mul_mv_ext_q6_K_f32_r1_5")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<5, block_q6_K, 256, dequantize_q6_K>;
|
||||
|
||||
template<typename T0, typename T1, short NR0, short NW, typename args_t>
|
||||
template<typename T0, typename T1, short NR0, typename args_t>
|
||||
void kernel_mul_mv_t_t_impl(
|
||||
args_t args,
|
||||
device const char * src0,
|
||||
@@ -3422,6 +3423,7 @@ void kernel_mul_mv_t_t_impl(
|
||||
ushort sgitg) {
|
||||
const short NSG = FC_mul_mv_nsg;
|
||||
|
||||
constexpr short NW = N_SIMDWIDTH;
|
||||
constexpr short NB = 32;
|
||||
constexpr short NF = 8;
|
||||
|
||||
@@ -3486,10 +3488,10 @@ void kernel_mul_mv_t_t_impl(
|
||||
|
||||
device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
|
||||
|
||||
helper_mv_reduce_and_write<NR0, NW>(dst_f32, sumf, r0, args.ne01, tiisg, sgitg, shmem);
|
||||
helper_mv_reduce_and_write<NR0>(dst_f32, sumf, r0, args.ne01, tiisg, sgitg, shmem);
|
||||
}
|
||||
|
||||
template<typename T0, typename T1, short NR0, short NW>
|
||||
template<typename T0, typename T1, short NR0>
|
||||
kernel void kernel_mul_mv_t_t(
|
||||
constant ggml_metal_kargs_mul_mv & args,
|
||||
device const char * src0,
|
||||
@@ -3499,20 +3501,20 @@ kernel void kernel_mul_mv_t_t(
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
ushort tiisg[[thread_index_in_simdgroup]],
|
||||
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||
kernel_mul_mv_t_t_impl<T0, T1, NR0, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
|
||||
kernel_mul_mv_t_t_impl<T0, T1, NR0, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
|
||||
}
|
||||
|
||||
typedef decltype(kernel_mul_mv_t_t<half, half, N_R0_F, N_SIMDWIDTH>) mul_mv_t_t;
|
||||
typedef decltype(kernel_mul_mv_t_t<half, half, N_R0_F>) mul_mv_t_t;
|
||||
|
||||
template [[host_name("kernel_mul_mv_f32_f32")]] kernel mul_mv_t_t kernel_mul_mv_t_t<float, float, N_R0_F, N_SIMDWIDTH>;
|
||||
template [[host_name("kernel_mul_mv_f16_f32")]] kernel mul_mv_t_t kernel_mul_mv_t_t<half, float, N_R0_F, N_SIMDWIDTH>;
|
||||
template [[host_name("kernel_mul_mv_f16_f16")]] kernel mul_mv_t_t kernel_mul_mv_t_t<half, half, N_R0_F, N_SIMDWIDTH>;
|
||||
template [[host_name("kernel_mul_mv_f32_f32")]] kernel mul_mv_t_t kernel_mul_mv_t_t<float, float, N_R0_F>;
|
||||
template [[host_name("kernel_mul_mv_f16_f32")]] kernel mul_mv_t_t kernel_mul_mv_t_t<half, float, N_R0_F>;
|
||||
template [[host_name("kernel_mul_mv_f16_f16")]] kernel mul_mv_t_t kernel_mul_mv_t_t<half, half, N_R0_F>;
|
||||
#if defined(GGML_METAL_HAS_BF16)
|
||||
template [[host_name("kernel_mul_mv_bf16_f32")]] kernel mul_mv_t_t kernel_mul_mv_t_t<bfloat, float, N_R0_F, N_SIMDWIDTH>;
|
||||
template [[host_name("kernel_mul_mv_bf16_bf16")]] kernel mul_mv_t_t kernel_mul_mv_t_t<bfloat, bfloat, N_R0_F, N_SIMDWIDTH>;
|
||||
template [[host_name("kernel_mul_mv_bf16_f32")]] kernel mul_mv_t_t kernel_mul_mv_t_t<bfloat, float, N_R0_F>;
|
||||
template [[host_name("kernel_mul_mv_bf16_bf16")]] kernel mul_mv_t_t kernel_mul_mv_t_t<bfloat, bfloat, N_R0_F>;
|
||||
#endif
|
||||
|
||||
template<typename T0, typename T04, typename T1, typename T14, short NR0, short NW, typename args_t>
|
||||
template<typename T0, typename T04, typename T1, typename T14, short NR0, typename args_t>
|
||||
void kernel_mul_mv_t_t_4_impl(
|
||||
args_t args,
|
||||
device const char * src0,
|
||||
@@ -3524,6 +3526,7 @@ void kernel_mul_mv_t_t_4_impl(
|
||||
ushort sgitg) {
|
||||
const short NSG = FC_mul_mv_nsg;
|
||||
|
||||
constexpr short NW = N_SIMDWIDTH;
|
||||
constexpr short NB = 32;
|
||||
constexpr short NF = 16;
|
||||
constexpr short NF4 = NF/4;
|
||||
@@ -3591,10 +3594,10 @@ void kernel_mul_mv_t_t_4_impl(
|
||||
|
||||
device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
|
||||
|
||||
helper_mv_reduce_and_write<NR0, NW>(dst_f32, sumf, r0, args.ne01, tiisg, sgitg, shmem);
|
||||
helper_mv_reduce_and_write<NR0>(dst_f32, sumf, r0, args.ne01, tiisg, sgitg, shmem);
|
||||
}
|
||||
|
||||
template<typename T0, typename T04, typename T1, typename T14, short NR0, short NW>
|
||||
template<typename T0, typename T04, typename T1, typename T14, short NR0>
|
||||
kernel void kernel_mul_mv_t_t_4(
|
||||
constant ggml_metal_kargs_mul_mv & args,
|
||||
device const char * src0,
|
||||
@@ -3604,17 +3607,17 @@ kernel void kernel_mul_mv_t_t_4(
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
ushort tiisg[[thread_index_in_simdgroup]],
|
||||
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||
kernel_mul_mv_t_t_4_impl<T0, T04, T1, T14, NR0, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
|
||||
kernel_mul_mv_t_t_4_impl<T0, T04, T1, T14, NR0, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
|
||||
}
|
||||
|
||||
typedef decltype(kernel_mul_mv_t_t_4<half, half4, half, half4, N_R0_F, N_SIMDWIDTH>) mul_mv_t_t_4;
|
||||
typedef decltype(kernel_mul_mv_t_t_4<half, half4, half, half4, N_R0_F>) mul_mv_t_t_4;
|
||||
|
||||
template [[host_name("kernel_mul_mv_f32_f32_4")]] kernel mul_mv_t_t_4 kernel_mul_mv_t_t_4<float, float4, float, float4, N_R0_F, N_SIMDWIDTH>;
|
||||
template [[host_name("kernel_mul_mv_f16_f32_4")]] kernel mul_mv_t_t_4 kernel_mul_mv_t_t_4<half, half4, float, float4, N_R0_F, N_SIMDWIDTH>;
|
||||
template [[host_name("kernel_mul_mv_f16_f16_4")]] kernel mul_mv_t_t_4 kernel_mul_mv_t_t_4<half, half4, half, half4, N_R0_F, N_SIMDWIDTH>;
|
||||
template [[host_name("kernel_mul_mv_f32_f32_4")]] kernel mul_mv_t_t_4 kernel_mul_mv_t_t_4<float, float4, float, float4, N_R0_F>;
|
||||
template [[host_name("kernel_mul_mv_f16_f32_4")]] kernel mul_mv_t_t_4 kernel_mul_mv_t_t_4<half, half4, float, float4, N_R0_F>;
|
||||
template [[host_name("kernel_mul_mv_f16_f16_4")]] kernel mul_mv_t_t_4 kernel_mul_mv_t_t_4<half, half4, half, half4, N_R0_F>;
|
||||
#if defined(GGML_METAL_HAS_BF16)
|
||||
template [[host_name("kernel_mul_mv_bf16_f32_4")]] kernel mul_mv_t_t_4 kernel_mul_mv_t_t_4<bfloat, bfloat4, float, float4, N_R0_F, N_SIMDWIDTH>;
|
||||
template [[host_name("kernel_mul_mv_bf16_bf16_4")]] kernel mul_mv_t_t_4 kernel_mul_mv_t_t_4<bfloat, bfloat4, bfloat, bfloat4, N_R0_F, N_SIMDWIDTH>;
|
||||
template [[host_name("kernel_mul_mv_bf16_f32_4")]] kernel mul_mv_t_t_4 kernel_mul_mv_t_t_4<bfloat, bfloat4, float, float4, N_R0_F>;
|
||||
template [[host_name("kernel_mul_mv_bf16_bf16_4")]] kernel mul_mv_t_t_4 kernel_mul_mv_t_t_4<bfloat, bfloat4, bfloat, bfloat4, N_R0_F>;
|
||||
#endif
|
||||
|
||||
#define N_MV_T_T 4
|
||||
@@ -5966,7 +5969,7 @@ kernel void kernel_concat(
|
||||
}
|
||||
}
|
||||
|
||||
template<int nr0, int nw, typename args_t>
|
||||
template<int nr0, typename args_t>
|
||||
void kernel_mul_mv_q2_K_f32_impl(
|
||||
args_t args,
|
||||
device const char * src0,
|
||||
@@ -6068,10 +6071,10 @@ kernel void kernel_mul_mv_q2_K_f32(
|
||||
ushort tiisg[[thread_index_in_simdgroup]],
|
||||
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||
|
||||
kernel_mul_mv_q2_K_f32_impl<N_R0_Q2_K, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
|
||||
kernel_mul_mv_q2_K_f32_impl<N_R0_Q2_K, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
|
||||
}
|
||||
|
||||
template<int nr0, int nw, typename args_t>
|
||||
template<int nr0, typename args_t>
|
||||
void kernel_mul_mv_q3_K_f32_impl(
|
||||
args_t args,
|
||||
device const char * src0,
|
||||
@@ -6233,10 +6236,10 @@ kernel void kernel_mul_mv_q3_K_f32(
|
||||
ushort tiisg[[thread_index_in_simdgroup]],
|
||||
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||
|
||||
kernel_mul_mv_q3_K_f32_impl<N_R0_Q3_K, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
|
||||
kernel_mul_mv_q3_K_f32_impl<N_R0_Q3_K, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
|
||||
}
|
||||
|
||||
template<int nr0, int nw, typename args_t>
|
||||
template<int nr0, typename args_t>
|
||||
void kernel_mul_mv_q4_K_f32_impl(
|
||||
args_t args,
|
||||
device const char * src0,
|
||||
@@ -6248,9 +6251,9 @@ void kernel_mul_mv_q4_K_f32_impl(
|
||||
ushort sgitg) {
|
||||
const short NSG = FC_mul_mv_nsg;
|
||||
|
||||
const uint16_t kmask1 = 0x3f3f;
|
||||
const uint16_t kmask2 = 0x0f0f;
|
||||
const uint16_t kmask3 = 0xc0c0;
|
||||
constexpr uint16_t kmask1 = 0x3f3f;
|
||||
constexpr uint16_t kmask2 = 0x0f0f;
|
||||
constexpr uint16_t kmask3 = 0xc0c0;
|
||||
|
||||
const short ix = tiisg/8; // 0...3
|
||||
const short it = tiisg%8; // 0...7
|
||||
@@ -6309,7 +6312,7 @@ void kernel_mul_mv_q4_K_f32_impl(
|
||||
float4 acc1 = {0.f, 0.f, 0.f, 0.f};
|
||||
float4 acc2 = {0.f, 0.f, 0.f, 0.f};
|
||||
|
||||
for (short i = 0; i < 4; ++i) {
|
||||
FOR_UNROLL (short i = 0; i < 4; ++i) {
|
||||
acc1[0] += yl[2*i + 0] * (q1[i] & 0x000F);
|
||||
acc1[1] += yl[2*i + 1] * (q1[i] & 0x0F00);
|
||||
acc1[2] += yl[2*i + 8] * (q1[i] & 0x00F0);
|
||||
@@ -6320,14 +6323,11 @@ void kernel_mul_mv_q4_K_f32_impl(
|
||||
acc2[3] += yh[2*i + 9] * (q2[i] & 0xF000);
|
||||
}
|
||||
|
||||
float dall = dh[0];
|
||||
float dmin = dh[1];
|
||||
|
||||
sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc1[1]) * sc8[0] +
|
||||
(acc1[2] + 1.f/256.f * acc1[3]) * sc8[1] * 1.f/16.f +
|
||||
(acc2[0] + 1.f/256.f * acc2[1]) * sc8[4] +
|
||||
(acc2[2] + 1.f/256.f * acc2[3]) * sc8[5] * 1.f/16.f) -
|
||||
dmin * (sumy[0] * sc8[2] + sumy[1] * sc8[3] + sumy[2] * sc8[6] + sumy[3] * sc8[7]);
|
||||
sumf[row] += dh[0] * ((acc1[0] + 1.f/256.f * acc1[1]) * sc8[0] +
|
||||
(acc1[2] + 1.f/256.f * acc1[3]) * sc8[1] * 1.f/16.f +
|
||||
(acc2[0] + 1.f/256.f * acc2[1]) * sc8[4] +
|
||||
(acc2[2] + 1.f/256.f * acc2[3]) * sc8[5] * 1.f/16.f) -
|
||||
dh[1] * (sumy[0] * sc8[2] + sumy[1] * sc8[3] + sumy[2] * sc8[6] + sumy[3] * sc8[7]);
|
||||
|
||||
q1 += args.nb01/2;
|
||||
sc += args.nb01/2;
|
||||
@@ -6357,10 +6357,10 @@ kernel void kernel_mul_mv_q4_K_f32(
|
||||
ushort tiisg[[thread_index_in_simdgroup]],
|
||||
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||
|
||||
kernel_mul_mv_q4_K_f32_impl<N_R0_Q4_K, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
|
||||
kernel_mul_mv_q4_K_f32_impl<N_R0_Q4_K, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
|
||||
}
|
||||
|
||||
template<int nr0, int nw, typename args_t>
|
||||
template<int nr0, typename args_t>
|
||||
void kernel_mul_mv_q5_K_f32_impl(
|
||||
args_t args,
|
||||
device const char * src0,
|
||||
@@ -6393,9 +6393,9 @@ void kernel_mul_mv_q5_K_f32_impl(
|
||||
|
||||
float yl[16], yh[16];
|
||||
|
||||
const uint16_t kmask1 = 0x3f3f;
|
||||
const uint16_t kmask2 = 0x0f0f;
|
||||
const uint16_t kmask3 = 0xc0c0;
|
||||
constexpr uint16_t kmask1 = 0x3f3f;
|
||||
constexpr uint16_t kmask2 = 0x0f0f;
|
||||
constexpr uint16_t kmask3 = 0xc0c0;
|
||||
|
||||
const short tid = tiisg/4;
|
||||
const short ix = tiisg%4;
|
||||
@@ -6441,7 +6441,7 @@ void kernel_mul_mv_q5_K_f32_impl(
|
||||
|
||||
float4 acc1 = {0.f};
|
||||
float4 acc2 = {0.f};
|
||||
for (short l = 0; l < 8; ++l) {
|
||||
FOR_UNROLL (short l = 0; l < 8; ++l) {
|
||||
uint8_t h = qh[l];
|
||||
acc1[0] += yl[l+0] * (q1[l] & 0x0F);
|
||||
acc1[1] += yl[l+8] * (q1[l] & 0xF0);
|
||||
@@ -6452,13 +6452,12 @@ void kernel_mul_mv_q5_K_f32_impl(
|
||||
acc2[2] += h & hm3 ? yh[l+0] : 0.f;
|
||||
acc2[3] += h & hm4 ? yh[l+8] : 0.f;
|
||||
}
|
||||
const float dall = dh[0];
|
||||
const float dmin = dh[1];
|
||||
sumf[row] += dall * (sc8[0] * (acc1[0] + 16.f*acc2[0]) +
|
||||
sc8[1] * (acc1[1]/16.f + 16.f*acc2[1]) +
|
||||
sc8[4] * (acc1[2] + 16.f*acc2[2]) +
|
||||
sc8[5] * (acc1[3]/16.f + 16.f*acc2[3])) -
|
||||
dmin * (sumy[0] * sc8[2] + sumy[1] * sc8[3] + sumy[2] * sc8[6] + sumy[3] * sc8[7]);
|
||||
|
||||
sumf[row] += dh[0] * (sc8[0] * (acc1[0] + 16.f*acc2[0]) +
|
||||
sc8[1] * (acc1[1]/16.f + 16.f*acc2[1]) +
|
||||
sc8[4] * (acc1[2] + 16.f*acc2[2]) +
|
||||
sc8[5] * (acc1[3]/16.f + 16.f*acc2[3])) -
|
||||
dh[1] * (sumy[0] * sc8[2] + sumy[1] * sc8[3] + sumy[2] * sc8[6] + sumy[3] * sc8[7]);
|
||||
|
||||
q1 += args.nb01;
|
||||
qh += args.nb01;
|
||||
@@ -6489,10 +6488,10 @@ kernel void kernel_mul_mv_q5_K_f32(
|
||||
ushort tiisg[[thread_index_in_simdgroup]],
|
||||
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||
|
||||
kernel_mul_mv_q5_K_f32_impl<N_R0_Q5_K, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
|
||||
kernel_mul_mv_q5_K_f32_impl<N_R0_Q5_K, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
|
||||
}
|
||||
|
||||
template<int nr0, int nw, typename args_t>
|
||||
template<int nr0, typename args_t>
|
||||
void kernel_mul_mv_q6_K_f32_impl(
|
||||
args_t args,
|
||||
device const char * src0,
|
||||
@@ -6504,10 +6503,10 @@ void kernel_mul_mv_q6_K_f32_impl(
|
||||
ushort sgitg) {
|
||||
const short NSG = FC_mul_mv_nsg;
|
||||
|
||||
const uint8_t kmask1 = 0x03;
|
||||
const uint8_t kmask2 = 0x0C;
|
||||
const uint8_t kmask3 = 0x30;
|
||||
const uint8_t kmask4 = 0xC0;
|
||||
constexpr uint8_t kmask1 = 0x03;
|
||||
constexpr uint8_t kmask2 = 0x0C;
|
||||
constexpr uint8_t kmask3 = 0x30;
|
||||
constexpr uint8_t kmask4 = 0xC0;
|
||||
|
||||
const int nb = args.ne00/QK_K;
|
||||
|
||||
@@ -6558,18 +6557,16 @@ void kernel_mul_mv_q6_K_f32_impl(
|
||||
}
|
||||
|
||||
for (short row = 0; row < nr0; ++row) {
|
||||
const float dall = dh[0];
|
||||
|
||||
float4 sums = {0.f, 0.f, 0.f, 0.f};
|
||||
|
||||
for (short l = 0; l < 4; ++l) {
|
||||
FOR_UNROLL (short l = 0; l < 4; ++l) {
|
||||
sums[0] += yl[4*l + 0] * ((int8_t)((q1[l] & 0xF) | ((qh[l] & kmask1) << 4)) - 32);
|
||||
sums[1] += yl[4*l + 1] * ((int8_t)((q2[l] & 0xF) | ((qh[l] & kmask2) << 2)) - 32);
|
||||
sums[2] += yl[4*l + 2] * ((int8_t)((q1[l] >> 4) | ((qh[l] & kmask3) << 0)) - 32);
|
||||
sums[3] += yl[4*l + 3] * ((int8_t)((q2[l] >> 4) | ((qh[l] & kmask4) >> 2)) - 32);
|
||||
}
|
||||
|
||||
sumf[row] += dall * (sums[0] * sc[0] + sums[1] * sc[2] + sums[2] * sc[4] + sums[3] * sc[6]);
|
||||
sumf[row] += dh[0] * (sums[0] * sc[0] + sums[1] * sc[2] + sums[2] * sc[4] + sums[3] * sc[6]);
|
||||
|
||||
q1 += args.nb01;
|
||||
q2 += args.nb01;
|
||||
@@ -6599,12 +6596,12 @@ kernel void kernel_mul_mv_q6_K_f32(
|
||||
ushort tiisg[[thread_index_in_simdgroup]],
|
||||
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||
|
||||
kernel_mul_mv_q6_K_f32_impl<N_R0_Q6_K, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
|
||||
kernel_mul_mv_q6_K_f32_impl<N_R0_Q6_K, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
|
||||
}
|
||||
|
||||
// ======================= "True" 2-bit
|
||||
|
||||
template<int nr0, int nw, typename args_t>
|
||||
template<int nr0, typename args_t>
|
||||
void kernel_mul_mv_iq2_xxs_f32_impl(
|
||||
args_t args,
|
||||
device const char * src0,
|
||||
@@ -6709,10 +6706,10 @@ kernel void kernel_mul_mv_iq2_xxs_f32(
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
ushort tiisg[[thread_index_in_simdgroup]],
|
||||
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||
kernel_mul_mv_iq2_xxs_f32_impl<N_R0_IQ2_XXS, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
|
||||
kernel_mul_mv_iq2_xxs_f32_impl<N_R0_IQ2_XXS, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
|
||||
}
|
||||
|
||||
template<int nr0, int nw, typename args_t>
|
||||
template<int nr0, typename args_t>
|
||||
void kernel_mul_mv_iq2_xs_f32_impl(
|
||||
args_t args,
|
||||
device const char * src0,
|
||||
@@ -6828,10 +6825,10 @@ kernel void kernel_mul_mv_iq2_xs_f32(
|
||||
ushort tiisg[[thread_index_in_simdgroup]],
|
||||
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||
|
||||
kernel_mul_mv_iq2_xs_f32_impl<N_R0_IQ2_XS, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
|
||||
kernel_mul_mv_iq2_xs_f32_impl<N_R0_IQ2_XS, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
|
||||
}
|
||||
|
||||
template<int nr0, int nw, typename args_t>
|
||||
template<int nr0, typename args_t>
|
||||
void kernel_mul_mv_iq3_xxs_f32_impl(
|
||||
args_t args,
|
||||
device const char * src0,
|
||||
@@ -6940,10 +6937,10 @@ kernel void kernel_mul_mv_iq3_xxs_f32(
|
||||
ushort tiisg[[thread_index_in_simdgroup]],
|
||||
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||
|
||||
kernel_mul_mv_iq3_xxs_f32_impl<N_R0_IQ3_XXS, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
|
||||
kernel_mul_mv_iq3_xxs_f32_impl<N_R0_IQ3_XXS, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
|
||||
}
|
||||
|
||||
template<int nr0, int nw, typename args_t>
|
||||
template<int nr0, typename args_t>
|
||||
void kernel_mul_mv_iq3_s_f32_impl(
|
||||
args_t args,
|
||||
device const char * src0,
|
||||
@@ -7052,10 +7049,10 @@ kernel void kernel_mul_mv_iq3_s_f32(
|
||||
ushort tiisg[[thread_index_in_simdgroup]],
|
||||
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||
|
||||
kernel_mul_mv_iq3_s_f32_impl<N_R0_IQ3_S, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
|
||||
kernel_mul_mv_iq3_s_f32_impl<N_R0_IQ3_S, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
|
||||
}
|
||||
|
||||
template<int nr0, int nw, typename args_t>
|
||||
template<int nr0, typename args_t>
|
||||
void kernel_mul_mv_iq2_s_f32_impl(
|
||||
args_t args,
|
||||
device const char * src0,
|
||||
@@ -7165,10 +7162,10 @@ kernel void kernel_mul_mv_iq2_s_f32(
|
||||
ushort tiisg[[thread_index_in_simdgroup]],
|
||||
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||
|
||||
kernel_mul_mv_iq2_s_f32_impl<N_R0_IQ2_S, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
|
||||
kernel_mul_mv_iq2_s_f32_impl<N_R0_IQ2_S, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
|
||||
}
|
||||
|
||||
template<int nr0, int nw, typename args_t>
|
||||
template<int nr0, typename args_t>
|
||||
void kernel_mul_mv_iq1_s_f32_impl(
|
||||
args_t args,
|
||||
device const char * src0,
|
||||
@@ -7264,10 +7261,10 @@ kernel void kernel_mul_mv_iq1_s_f32(
|
||||
ushort tiisg[[thread_index_in_simdgroup]],
|
||||
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||
|
||||
kernel_mul_mv_iq1_s_f32_impl<N_R0_IQ1_S, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
|
||||
kernel_mul_mv_iq1_s_f32_impl<N_R0_IQ1_S, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
|
||||
}
|
||||
|
||||
template<int nr0, int nw, typename args_t>
|
||||
template<int nr0, typename args_t>
|
||||
void kernel_mul_mv_iq1_m_f32_impl(
|
||||
args_t args,
|
||||
device const char * src0,
|
||||
@@ -7373,10 +7370,10 @@ kernel void kernel_mul_mv_iq1_m_f32(
|
||||
ushort tiisg[[thread_index_in_simdgroup]],
|
||||
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||
|
||||
kernel_mul_mv_iq1_m_f32_impl<N_R0_IQ1_M, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
|
||||
kernel_mul_mv_iq1_m_f32_impl<N_R0_IQ1_M, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
|
||||
}
|
||||
|
||||
template<int nr0, int nw, typename args_t>
|
||||
template<int nr0, typename args_t>
|
||||
void kernel_mul_mv_iq4_nl_f32_impl(
|
||||
args_t args,
|
||||
device const char * src0,
|
||||
@@ -7480,10 +7477,10 @@ kernel void kernel_mul_mv_iq4_nl_f32(
|
||||
ushort tiisg[[thread_index_in_simdgroup]],
|
||||
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||
|
||||
kernel_mul_mv_iq4_nl_f32_impl<N_R0_IQ4_NL, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
|
||||
kernel_mul_mv_iq4_nl_f32_impl<N_R0_IQ4_NL, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
|
||||
}
|
||||
|
||||
template<int nr0, int nw, typename args_t>
|
||||
template<int nr0, typename args_t>
|
||||
void kernel_mul_mv_iq4_xs_f32_impl(
|
||||
args_t args,
|
||||
device const char * src0,
|
||||
@@ -7587,10 +7584,10 @@ kernel void kernel_mul_mv_iq4_xs_f32(
|
||||
ushort tiisg[[thread_index_in_simdgroup]],
|
||||
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||
|
||||
kernel_mul_mv_iq4_xs_f32_impl<N_R0_IQ4_XS, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
|
||||
kernel_mul_mv_iq4_xs_f32_impl<N_R0_IQ4_XS, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
|
||||
}
|
||||
|
||||
template<int nr0, int nw, typename args_t>
|
||||
template<int nr0, typename args_t>
|
||||
void kernel_mul_mv_mxfp4_f32_impl(
|
||||
args_t args,
|
||||
device const char * src0,
|
||||
@@ -7677,7 +7674,7 @@ kernel void kernel_mul_mv_mxfp4_f32(
|
||||
ushort tiisg[[thread_index_in_simdgroup]],
|
||||
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||
|
||||
kernel_mul_mv_mxfp4_f32_impl<N_R0_MXFP4, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
|
||||
kernel_mul_mv_mxfp4_f32_impl<N_R0_MXFP4, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
|
||||
}
|
||||
|
||||
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread float4x4 &)>
|
||||
@@ -8353,7 +8350,7 @@ void mmv_fn(
|
||||
impl_fn(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
|
||||
}
|
||||
|
||||
typedef decltype(mmv_fn<kernel_mul_mv_t_t_impl<half, half, N_R0_F, N_SIMDWIDTH, ggml_metal_kargs_mul_mv>>) mul_mv_impl_fn_t;
|
||||
typedef decltype(mmv_fn<kernel_mul_mv_t_t_impl<half, half, N_R0_F, ggml_metal_kargs_mul_mv>>) mul_mv_impl_fn_t;
|
||||
|
||||
template<mul_mv_impl_fn_t impl_fn>
|
||||
kernel void kernel_mul_mv_id(
|
||||
@@ -8418,44 +8415,44 @@ kernel void kernel_mul_mv_id(
|
||||
sgitg);
|
||||
}
|
||||
|
||||
typedef decltype(kernel_mul_mv_id<mmv_fn<kernel_mul_mv_t_t_impl<float, float, N_R0_F, N_SIMDWIDTH>>>) kernel_mul_mv_id_t;
|
||||
typedef decltype(kernel_mul_mv_id<mmv_fn<kernel_mul_mv_t_t_impl<float, float, N_R0_F>>>) kernel_mul_mv_id_t;
|
||||
|
||||
typedef decltype(kernel_mul_mv_id<mmv_fn<kernel_mul_mv_t_t_4_impl<float, float4, float, float4, N_R0_F, N_SIMDWIDTH>>>) kernel_mul_mv_id_4_t;
|
||||
typedef decltype(kernel_mul_mv_id<mmv_fn<kernel_mul_mv_t_t_4_impl<float, float4, float, float4, N_R0_F>>>) kernel_mul_mv_id_4_t;
|
||||
|
||||
template [[host_name("kernel_mul_mv_id_f32_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_t_t_impl<float, float, N_R0_F, N_SIMDWIDTH>>>;
|
||||
template [[host_name("kernel_mul_mv_id_f16_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_t_t_impl<half, float, N_R0_F, N_SIMDWIDTH>>>;
|
||||
template [[host_name("kernel_mul_mv_id_f32_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_t_t_impl<float, float, N_R0_F>>>;
|
||||
template [[host_name("kernel_mul_mv_id_f16_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_t_t_impl<half, float, N_R0_F>>>;
|
||||
#if defined(GGML_METAL_HAS_BF16)
|
||||
template [[host_name("kernel_mul_mv_id_bf16_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_t_t_impl<bfloat, float, N_R0_F, N_SIMDWIDTH>>>;
|
||||
template [[host_name("kernel_mul_mv_id_bf16_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_t_t_impl<bfloat, float, N_R0_F>>>;
|
||||
#endif
|
||||
template [[host_name("kernel_mul_mv_id_f32_f32_4")]] kernel kernel_mul_mv_id_4_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_t_t_4_impl<float, float4, float, float4, N_R0_F, N_SIMDWIDTH>>>;
|
||||
template [[host_name("kernel_mul_mv_id_f16_f32_4")]] kernel kernel_mul_mv_id_4_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_t_t_4_impl<half, half4, float, float4, N_R0_F, N_SIMDWIDTH>>>;
|
||||
template [[host_name("kernel_mul_mv_id_f32_f32_4")]] kernel kernel_mul_mv_id_4_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_t_t_4_impl<float, float4, float, float4, N_R0_F>>>;
|
||||
template [[host_name("kernel_mul_mv_id_f16_f32_4")]] kernel kernel_mul_mv_id_4_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_t_t_4_impl<half, half4, float, float4, N_R0_F>>>;
|
||||
#if defined(GGML_METAL_HAS_BF16)
|
||||
template [[host_name("kernel_mul_mv_id_bf16_f32_4")]] kernel kernel_mul_mv_id_4_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_t_t_4_impl<bfloat, bfloat4, float, float4, N_R0_F, N_SIMDWIDTH>>>;
|
||||
template [[host_name("kernel_mul_mv_id_bf16_f32_4")]] kernel kernel_mul_mv_id_4_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_t_t_4_impl<bfloat, bfloat4, float, float4, N_R0_F>>>;
|
||||
#endif
|
||||
|
||||
template [[host_name("kernel_mul_mv_id_q8_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q8_0_f32_impl<N_R0_Q8_0, N_SIMDWIDTH>>>;
|
||||
template [[host_name("kernel_mul_mv_id_q8_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q8_0_f32_impl<N_R0_Q8_0>>>;
|
||||
|
||||
template [[host_name("kernel_mul_mv_id_q4_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q4_0, N_R0_Q4_0, N_SIMDWIDTH>>>;
|
||||
template [[host_name("kernel_mul_mv_id_q4_1_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q4_1, N_R0_Q4_1, N_SIMDWIDTH>>>;
|
||||
template [[host_name("kernel_mul_mv_id_q5_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q5_0, N_R0_Q5_0, N_SIMDWIDTH>>>;
|
||||
template [[host_name("kernel_mul_mv_id_q5_1_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q5_1, N_R0_Q5_1, N_SIMDWIDTH>>>;
|
||||
template [[host_name("kernel_mul_mv_id_q4_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q4_0, N_R0_Q4_0>>>;
|
||||
template [[host_name("kernel_mul_mv_id_q4_1_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q4_1, N_R0_Q4_1>>>;
|
||||
template [[host_name("kernel_mul_mv_id_q5_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q5_0, N_R0_Q5_0>>>;
|
||||
template [[host_name("kernel_mul_mv_id_q5_1_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q5_1, N_R0_Q5_1>>>;
|
||||
|
||||
template [[host_name("kernel_mul_mv_id_mxfp4_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_mxfp4_f32_impl<N_R0_MXFP4, N_SIMDWIDTH>>>;
|
||||
template [[host_name("kernel_mul_mv_id_mxfp4_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_mxfp4_f32_impl<N_R0_MXFP4>>>;
|
||||
|
||||
template [[host_name("kernel_mul_mv_id_q2_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q2_K_f32_impl <N_R0_Q2_K, N_SIMDWIDTH>>>;
|
||||
template [[host_name("kernel_mul_mv_id_q3_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q3_K_f32_impl <N_R0_Q3_K, N_SIMDWIDTH>>>;
|
||||
template [[host_name("kernel_mul_mv_id_q4_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q4_K_f32_impl <N_R0_Q4_K, N_SIMDWIDTH>>>;
|
||||
template [[host_name("kernel_mul_mv_id_q5_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q5_K_f32_impl <N_R0_Q5_K, N_SIMDWIDTH>>>;
|
||||
template [[host_name("kernel_mul_mv_id_q6_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q6_K_f32_impl <N_R0_Q6_K, N_SIMDWIDTH>>>;
|
||||
template [[host_name("kernel_mul_mv_id_iq1_s_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq1_s_f32_impl <N_R0_IQ1_S, N_SIMDWIDTH>>>;
|
||||
template [[host_name("kernel_mul_mv_id_iq1_m_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq1_m_f32_impl <N_R0_IQ1_M, N_SIMDWIDTH>>>;
|
||||
template [[host_name("kernel_mul_mv_id_iq2_xxs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq2_xxs_f32_impl<N_R0_IQ2_XXS, N_SIMDWIDTH>>>;
|
||||
template [[host_name("kernel_mul_mv_id_iq2_xs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq2_xs_f32_impl <N_R0_IQ2_XS, N_SIMDWIDTH>>>;
|
||||
template [[host_name("kernel_mul_mv_id_iq3_xxs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq3_xxs_f32_impl<N_R0_IQ3_XXS, N_SIMDWIDTH>>>;
|
||||
template [[host_name("kernel_mul_mv_id_iq3_s_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq3_s_f32_impl <N_R0_IQ3_S, N_SIMDWIDTH>>>;
|
||||
template [[host_name("kernel_mul_mv_id_iq2_s_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq2_s_f32_impl <N_R0_IQ2_S, N_SIMDWIDTH>>>;
|
||||
template [[host_name("kernel_mul_mv_id_iq4_nl_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq4_nl_f32_impl <N_R0_IQ4_NL, N_SIMDWIDTH>>>;
|
||||
template [[host_name("kernel_mul_mv_id_iq4_xs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq4_xs_f32_impl <N_R0_IQ4_XS, N_SIMDWIDTH>>>;
|
||||
template [[host_name("kernel_mul_mv_id_q2_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q2_K_f32_impl <N_R0_Q2_K>>>;
|
||||
template [[host_name("kernel_mul_mv_id_q3_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q3_K_f32_impl <N_R0_Q3_K>>>;
|
||||
template [[host_name("kernel_mul_mv_id_q4_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q4_K_f32_impl <N_R0_Q4_K>>>;
|
||||
template [[host_name("kernel_mul_mv_id_q5_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q5_K_f32_impl <N_R0_Q5_K>>>;
|
||||
template [[host_name("kernel_mul_mv_id_q6_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q6_K_f32_impl <N_R0_Q6_K>>>;
|
||||
template [[host_name("kernel_mul_mv_id_iq1_s_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq1_s_f32_impl <N_R0_IQ1_S>>>;
|
||||
template [[host_name("kernel_mul_mv_id_iq1_m_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq1_m_f32_impl <N_R0_IQ1_M>>>;
|
||||
template [[host_name("kernel_mul_mv_id_iq2_xxs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq2_xxs_f32_impl<N_R0_IQ2_XXS>>>;
|
||||
template [[host_name("kernel_mul_mv_id_iq2_xs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq2_xs_f32_impl <N_R0_IQ2_XS>>>;
|
||||
template [[host_name("kernel_mul_mv_id_iq3_xxs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq3_xxs_f32_impl<N_R0_IQ3_XXS>>>;
|
||||
template [[host_name("kernel_mul_mv_id_iq3_s_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq3_s_f32_impl <N_R0_IQ3_S>>>;
|
||||
template [[host_name("kernel_mul_mv_id_iq2_s_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq2_s_f32_impl <N_R0_IQ2_S>>>;
|
||||
template [[host_name("kernel_mul_mv_id_iq4_nl_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq4_nl_f32_impl <N_R0_IQ4_NL>>>;
|
||||
template [[host_name("kernel_mul_mv_id_iq4_xs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq4_xs_f32_impl <N_R0_IQ4_XS>>>;
|
||||
|
||||
kernel void kernel_pool_2d_max_f32(
|
||||
constant ggml_metal_kargs_pool_2d & args,
|
||||
|
||||
@@ -83,8 +83,10 @@ set(GGML_OPENCL_KERNELS
|
||||
mul_mv_q4_0_f32_1d_16x_flat
|
||||
mul_mv_q6_k
|
||||
mul_mv_mxfp4_f32
|
||||
mul_mv_mxfp4_f32_flat
|
||||
mul_mv_id_q4_0_f32_8x_flat
|
||||
mul_mv_id_mxfp4_f32
|
||||
mul_mv_id_mxfp4_f32_flat
|
||||
mul_mm_f32_f32_l4_lm
|
||||
mul_mm_f16_f32_l4_lm
|
||||
mul
|
||||
|
||||
@@ -368,6 +368,7 @@ struct ggml_backend_opencl_context {
|
||||
cl_program program_mul_mv_q4_0_f32_1d_16x_flat;
|
||||
cl_program program_mul_mv_q6_K;
|
||||
cl_program program_mul_mv_mxfp4_f32;
|
||||
cl_program program_mul_mv_mxfp4_f32_flat;
|
||||
cl_program program_mul_mv_f16_f16;
|
||||
cl_program program_mul_mv_f16_f32_1row;
|
||||
cl_program program_mul_mv_f16_f32_l4;
|
||||
@@ -402,6 +403,7 @@ struct ggml_backend_opencl_context {
|
||||
cl_program program_tsembd;
|
||||
cl_program program_mul_mv_id_q4_0_f32_8x_flat;
|
||||
cl_program program_mul_mv_id_mxfp4_f32;
|
||||
cl_program program_mul_mv_id_mxfp4_f32_flat;
|
||||
cl_program program_mul_mm_f32_f32_l4_lm;
|
||||
cl_program program_mul_mm_f16_f32_l4_lm;
|
||||
|
||||
@@ -447,11 +449,12 @@ struct ggml_backend_opencl_context {
|
||||
cl_kernel kernel_mul_mat_f16_f32_tiled;
|
||||
cl_kernel kernel_mul_mat_q4_0_f32, kernel_mul_mat_q4_0_f32_v;
|
||||
cl_kernel kernel_convert_block_q4_0, kernel_restore_block_q4_0;
|
||||
cl_kernel kernel_convert_block_mxfp4, kernel_restore_block_mxfp4;
|
||||
cl_kernel kernel_mul_mat_q4_0_f32_8x_flat;
|
||||
cl_kernel kernel_convert_block_q4_0_noshuffle;
|
||||
cl_kernel kernel_mul_mat_q4_0_f32_1d_8x_flat, kernel_mul_mat_q4_0_f32_1d_16x_flat;
|
||||
cl_kernel kernel_mul_mv_q6_K_f32;
|
||||
cl_kernel kernel_mul_mv_mxfp4_f32;
|
||||
cl_kernel kernel_mul_mv_mxfp4_f32, kernel_mul_mv_mxfp4_f32_flat;
|
||||
cl_kernel kernel_im2col_f32, kernel_im2col_f16;
|
||||
cl_kernel kernel_argsort_f32_i32;
|
||||
cl_kernel kernel_sum_rows_f32;
|
||||
@@ -469,6 +472,7 @@ struct ggml_backend_opencl_context {
|
||||
cl_kernel kernel_timestep_embedding;
|
||||
cl_kernel kernel_mul_mv_id_q4_0_f32_8x_flat;
|
||||
cl_kernel kernel_mul_mv_id_mxfp4_f32;
|
||||
cl_kernel kernel_mul_mv_id_mxfp4_f32_flat;
|
||||
cl_kernel kernel_mul_mm_f32_f32_l4_lm;
|
||||
cl_kernel kernel_mul_mm_f16_f32_l4_lm;
|
||||
|
||||
@@ -765,6 +769,8 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
|
||||
CL_CHECK((backend_ctx->kernel_convert_block_q4_0_noshuffle = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q4_0_noshuffle", &err), err));
|
||||
CL_CHECK((backend_ctx->kernel_convert_block_q4_0 = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q4_0", &err), err));
|
||||
CL_CHECK((backend_ctx->kernel_restore_block_q4_0 = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q4_0", &err), err));
|
||||
CL_CHECK((backend_ctx->kernel_convert_block_mxfp4 = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_mxfp4", &err), err));
|
||||
CL_CHECK((backend_ctx->kernel_restore_block_mxfp4 = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_mxfp4", &err), err));
|
||||
GGML_LOG_CONT(".");
|
||||
}
|
||||
|
||||
@@ -1002,6 +1008,22 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
|
||||
GGML_LOG_CONT(".");
|
||||
}
|
||||
|
||||
// mul_mv_mxfp4_f32_flat
|
||||
{
|
||||
#ifdef GGML_OPENCL_EMBED_KERNELS
|
||||
const std::string kernel_src {
|
||||
#include "mul_mv_mxfp4_f32_flat.cl.h"
|
||||
};
|
||||
#else
|
||||
const std::string kernel_src = read_file("mul_mv_mxfp4_f32_flat.cl");
|
||||
#endif
|
||||
backend_ctx->program_mul_mv_mxfp4_f32_flat =
|
||||
build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
|
||||
|
||||
CL_CHECK((backend_ctx->kernel_mul_mv_mxfp4_f32_flat = clCreateKernel(backend_ctx->program_mul_mv_mxfp4_f32_flat, "kernel_mul_mv_mxfp4_f32_flat", &err), err));
|
||||
GGML_LOG_CONT(".");
|
||||
}
|
||||
|
||||
// mul_mv_f16_f16
|
||||
{
|
||||
#ifdef GGML_OPENCL_EMBED_KERNELS
|
||||
@@ -1727,6 +1749,22 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
|
||||
GGML_LOG_CONT(".");
|
||||
}
|
||||
|
||||
// mul_mv_id_mxfp4_f32_flat
|
||||
{
|
||||
#ifdef GGML_OPENCL_EMBED_KERNELS
|
||||
const std::string kernel_src {
|
||||
#include "mul_mv_id_mxfp4_f32_flat.cl.h"
|
||||
};
|
||||
#else
|
||||
const std::string kernel_src = read_file("mul_mv_id_mxfp4_f32_flat.cl");
|
||||
#endif
|
||||
backend_ctx->program_mul_mv_id_mxfp4_f32_flat =
|
||||
build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
|
||||
|
||||
CL_CHECK((backend_ctx->kernel_mul_mv_id_mxfp4_f32_flat = clCreateKernel(backend_ctx->program_mul_mv_id_mxfp4_f32_flat, "kernel_mul_mv_id_mxfp4_f32_flat", &err), err));
|
||||
GGML_LOG_CONT(".");
|
||||
}
|
||||
|
||||
// Adreno kernels
|
||||
#ifdef GGML_OPENCL_USE_ADRENO_KERNELS
|
||||
// transpose
|
||||
@@ -2391,6 +2429,51 @@ struct ggml_tensor_extra_cl_q4_0 {
|
||||
}
|
||||
};
|
||||
|
||||
struct ggml_tensor_extra_cl_mxfp4 {
|
||||
// Quantized values.
|
||||
cl_mem q = nullptr;
|
||||
// Quantized values in image1d_buffer_t.
|
||||
cl_mem q_img = nullptr;
|
||||
// Scales in E8M0.
|
||||
cl_mem e = nullptr;
|
||||
// Scales in image1d_buffer_t.
|
||||
cl_mem e_img = nullptr;
|
||||
// Size of quantized values.
|
||||
size_t size_q = 0;
|
||||
// Size of scales.
|
||||
size_t size_e = 0;
|
||||
|
||||
~ggml_tensor_extra_cl_mxfp4() {
|
||||
reset();
|
||||
}
|
||||
|
||||
void reset() {
|
||||
// q and d are subbuffers into the bigger buffer allocated in ggml_backend_buffer.
|
||||
// They must be properly released so that the original buffer can be
|
||||
// properly released to avoid memory leak.
|
||||
if (q != nullptr) {
|
||||
CL_CHECK(clReleaseMemObject(q));
|
||||
q = nullptr;
|
||||
}
|
||||
if (e != nullptr) {
|
||||
CL_CHECK(clReleaseMemObject(e));
|
||||
e = nullptr;
|
||||
}
|
||||
if (q != nullptr) {
|
||||
CL_CHECK(clReleaseMemObject(q_img));
|
||||
q = nullptr;
|
||||
}
|
||||
// Currently, q_img and d_img are only initialized when SMALL_ALLOC is
|
||||
// enabled. They point to the images in ggml_backend_opencl_buffer_context.
|
||||
// So, there is no need to release them here.
|
||||
// TODO: initialize them for non SMALL_PATH path, or remove them.
|
||||
q_img = nullptr;
|
||||
e_img = nullptr;
|
||||
size_q = 0;
|
||||
size_e = 0;
|
||||
}
|
||||
};
|
||||
|
||||
//------------------------------------------------------------------------------
|
||||
// Backend API
|
||||
//------------------------------------------------------------------------------
|
||||
@@ -2838,7 +2921,7 @@ static ggml_backend_i ggml_backend_opencl_i = {
|
||||
/* .graph_compute = */ ggml_backend_opencl_graph_compute,
|
||||
/* .event_record = */ NULL,
|
||||
/* .event_wait = */ NULL,
|
||||
/* .optimize_graph = */ NULL,
|
||||
/* .graph_optimize = */ NULL,
|
||||
};
|
||||
|
||||
ggml_backend_t ggml_backend_opencl_init(void) {
|
||||
@@ -2894,6 +2977,12 @@ struct ggml_backend_opencl_buffer_context {
|
||||
for (ggml_tensor_extra_cl_q4_0 * e : temp_tensor_extras_q4_0_in_use) {
|
||||
delete e;
|
||||
}
|
||||
for (ggml_tensor_extra_cl_mxfp4 * e : temp_tensor_extras_mxfp4) {
|
||||
delete e;
|
||||
}
|
||||
for (ggml_tensor_extra_cl_mxfp4 * e : temp_tensor_extras_mxfp4_in_use) {
|
||||
delete e;
|
||||
}
|
||||
}
|
||||
|
||||
ggml_tensor_extra_cl * ggml_opencl_alloc_temp_tensor_extra() {
|
||||
@@ -2926,6 +3015,21 @@ struct ggml_backend_opencl_buffer_context {
|
||||
return extra;
|
||||
}
|
||||
|
||||
// Hands out an MXFP4 extra: reuses one from the pool of available extras when
// the pool is non-empty, otherwise allocates a new one. The returned extra is
// recorded as in use and reset before being returned.
ggml_tensor_extra_cl_mxfp4 * ggml_opencl_alloc_temp_tensor_extra_mxfp4() {
    ggml_tensor_extra_cl_mxfp4 * result = nullptr;

    if (!temp_tensor_extras_mxfp4.empty()) {
        // Recycle the most recently returned extra.
        result = temp_tensor_extras_mxfp4.back();
        temp_tensor_extras_mxfp4.pop_back();
    } else {
        result = new ggml_tensor_extra_cl_mxfp4();
    }

    temp_tensor_extras_mxfp4_in_use.push_back(result);

    // Drop any handles left over from a previous use.
    result->reset();
    return result;
}
|
||||
|
||||
void reset() {
|
||||
for (ggml_tensor_extra_cl * e : temp_tensor_extras_in_use) {
|
||||
temp_tensor_extras.push_back(e);
|
||||
@@ -2936,6 +3040,11 @@ struct ggml_backend_opencl_buffer_context {
|
||||
temp_tensor_extras_q4_0.push_back(e);
|
||||
}
|
||||
temp_tensor_extras_q4_0_in_use.clear();
|
||||
|
||||
for (ggml_tensor_extra_cl_mxfp4 * e : temp_tensor_extras_mxfp4_in_use) {
|
||||
temp_tensor_extras_mxfp4.push_back(e);
|
||||
}
|
||||
temp_tensor_extras_mxfp4_in_use.clear();
|
||||
}
|
||||
|
||||
// Pools for extras. Available extras are in `temp_tensor_extras`. Extras
|
||||
@@ -2947,6 +3056,8 @@ struct ggml_backend_opencl_buffer_context {
|
||||
std::vector<ggml_tensor_extra_cl *> temp_tensor_extras_in_use;
|
||||
std::vector<ggml_tensor_extra_cl_q4_0 *> temp_tensor_extras_q4_0;
|
||||
std::vector<ggml_tensor_extra_cl_q4_0 *> temp_tensor_extras_q4_0_in_use;
|
||||
std::vector<ggml_tensor_extra_cl_mxfp4 *> temp_tensor_extras_mxfp4;
|
||||
std::vector<ggml_tensor_extra_cl_mxfp4 *> temp_tensor_extras_mxfp4_in_use;
|
||||
|
||||
// The buffer_context is initially created by ggml_backend_buft_alloc_buffer
|
||||
// before any tensor is initialized (at the beginning of alloc_tensor_range).
|
||||
@@ -3289,6 +3400,76 @@ static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer,
|
||||
}
|
||||
#endif // GGML_OPENCL_USE_ADRENO_KERNELS
|
||||
|
||||
return;
|
||||
|
||||
}
|
||||
if (tensor->type == GGML_TYPE_MXFP4) {
|
||||
ggml_tensor_extra_cl * extra_orig = (ggml_tensor_extra_cl *)tensor->extra;
|
||||
GGML_ASSERT(extra_orig && "Tesnors in OpenCL backend should have been allocated and initialized");
|
||||
|
||||
// Allocate the new extra and create aliases from the original.
|
||||
ggml_backend_opencl_buffer_context * ctx = (ggml_backend_opencl_buffer_context *) buffer->context;
|
||||
ggml_tensor_extra_cl_mxfp4 * extra = ctx->ggml_opencl_alloc_temp_tensor_extra_mxfp4();
|
||||
|
||||
size_t size_e = ggml_nelements(tensor)/ggml_blck_size(tensor->type)*sizeof(char);
|
||||
size_t size_q = ggml_nelements(tensor)/ggml_blck_size(tensor->type)*ggml_blck_size(tensor->type)/2;
|
||||
GGML_ASSERT(size_e + size_q == ggml_nbytes(tensor) && "Incorrect tensor size");
|
||||
|
||||
cl_int err;
|
||||
cl_mem data_device = clCreateBuffer(context, CL_MEM_READ_WRITE,
|
||||
ggml_nbytes(tensor), NULL, &err);
|
||||
CL_CHECK(err);
|
||||
CL_CHECK(clEnqueueWriteBuffer(
|
||||
queue, data_device, CL_TRUE, 0,
|
||||
ggml_nbytes(tensor), data, 0, NULL, NULL));
|
||||
|
||||
// The original tensor memory is divided into scales and quants, i.e.,
|
||||
// we first store scales, then quants.
|
||||
cl_buffer_region region;
|
||||
|
||||
// Create subbuffer for scales.
|
||||
region.origin = align_to(extra_orig->offset + tensor->view_offs + offset, backend_ctx->alignment);
|
||||
region.size = size_e;
|
||||
extra->e = clCreateSubBuffer(
|
||||
extra_orig->data_device, CL_MEM_READ_WRITE,
|
||||
CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err);
|
||||
CL_CHECK(err);
|
||||
auto previous_origin = region.origin;
|
||||
|
||||
// Create subbuffer for quants.
|
||||
region.origin = align_to(previous_origin + size_e, backend_ctx->alignment);
|
||||
region.size = size_q;
|
||||
extra->q = clCreateSubBuffer(
|
||||
extra_orig->data_device, CL_MEM_READ_WRITE,
|
||||
CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err);
|
||||
CL_CHECK(err);
|
||||
|
||||
cl_kernel kernel = backend_ctx->kernel_convert_block_mxfp4;
|
||||
|
||||
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &data_device));
|
||||
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->q));
|
||||
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra->e));
|
||||
|
||||
size_t global_work_size[] = {(size_t)ggml_nelements(tensor)/ggml_blck_size(tensor->type), 1, 1};
|
||||
size_t local_work_size[] = {64, 1, 1};
|
||||
|
||||
cl_event evt;
|
||||
CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt));
|
||||
CL_CHECK(clWaitForEvents(1, &evt));
|
||||
CL_CHECK(clReleaseMemObject(data_device));
|
||||
|
||||
// Create image for Q
|
||||
cl_image_format img_format_q = {CL_RG, CL_UNSIGNED_INT32};
|
||||
cl_image_desc img_desc_q = {
|
||||
CL_MEM_OBJECT_IMAGE1D_BUFFER,
|
||||
static_cast<size_t>(ggml_nelements(tensor)/32*2),
|
||||
0, 0, 0, 0, 0, 0, 0,
|
||||
{ extra->q }
|
||||
};
|
||||
extra->q_img = clCreateImage(context, CL_MEM_READ_ONLY, &img_format_q, &img_desc_q, NULL, &err);
|
||||
|
||||
tensor->extra = extra;
|
||||
|
||||
return;
|
||||
}
|
||||
#endif // GGML_OPENCL_SOA_Q
|
||||
@@ -3337,6 +3518,31 @@ static void ggml_backend_opencl_buffer_get_tensor(ggml_backend_buffer_t buffer,
|
||||
size_t global_work_size[] = {(size_t)ggml_nelements(tensor)/ggml_blck_size(tensor->type), 1, 1};
|
||||
size_t local_work_size[] = {1, 1, 1};
|
||||
|
||||
cl_event evt;
|
||||
CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL,
|
||||
global_work_size, local_work_size, 0, NULL, &evt));
|
||||
CL_CHECK(clWaitForEvents(1, &evt));
|
||||
CL_CHECK(clEnqueueReadBuffer(
|
||||
queue, data_device, CL_TRUE, offset,
|
||||
size, data, 0, NULL, NULL));
|
||||
CL_CHECK(clReleaseMemObject(data_device));
|
||||
return;
|
||||
} else if (tensor->type == GGML_TYPE_MXFP4) {
|
||||
ggml_tensor_extra_cl_mxfp4 * extra = (ggml_tensor_extra_cl_mxfp4 *)tensor->extra;
|
||||
|
||||
cl_int err;
|
||||
cl_mem data_device = clCreateBuffer(context, CL_MEM_READ_WRITE,
|
||||
ggml_nbytes(tensor), NULL, &err);
|
||||
CL_CHECK(err);
|
||||
|
||||
cl_kernel kernel = backend_ctx->kernel_restore_block_mxfp4;
|
||||
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra->q));
|
||||
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->e));
|
||||
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &data_device));
|
||||
|
||||
size_t global_work_size[] = {(size_t)ggml_nelements(tensor)/ggml_blck_size(tensor->type), 1, 1};
|
||||
size_t local_work_size[] = {1, 1, 1};
|
||||
|
||||
cl_event evt;
|
||||
CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL,
|
||||
global_work_size, local_work_size, 0, NULL, &evt));
|
||||
@@ -3658,6 +3864,19 @@ static void dump_tensor(ggml_backend_t backend, const struct ggml_tensor * tenso
|
||||
CL_CHECK(clEnqueueReadBuffer(queue, extra->q, CL_TRUE, 0, size_q, buf_q, 0, NULL, NULL));
|
||||
CL_CHECK(clEnqueueReadBuffer(queue, extra->d, CL_TRUE, 0, size_d, buf_d, 0, NULL, NULL));
|
||||
CL_CHECK(clFinish(queue));
|
||||
} else if (tensor->type == GGML_TYPE_MXFP4) {
|
||||
ggml_tensor_extra_cl_mxfp4 * extra = (ggml_tensor_extra_cl_mxfp4 *) tensor->extra;
|
||||
GGML_ASSERT(extra);
|
||||
|
||||
size_t size_q = ggml_nelements(tensor)/QK_MXFP4 * QK_MXFP4/2;
|
||||
size_t size_e = ggml_nelements(tensor)/QK_MXFP4 * sizeof(char);
|
||||
GGML_ASSERT(size_q + size_e == ggml_nbytes(tensor));
|
||||
buf_q = malloc(size_q);
|
||||
buf_d = malloc(size_e);
|
||||
|
||||
CL_CHECK(clEnqueueReadBuffer(queue, extra->q, CL_TRUE, 0, size_q, buf_q, 0, NULL, NULL));
|
||||
CL_CHECK(clEnqueueReadBuffer(queue, extra->d, CL_TRUE, 0, size_e, buf_d, 0, NULL, NULL));
|
||||
CL_CHECK(clFinish(queue));
|
||||
} else {
|
||||
// Read out the tensor from GPU memory.
|
||||
ggml_tensor_extra_cl * extra = (ggml_tensor_extra_cl *) tensor->extra;
|
||||
@@ -6048,6 +6267,7 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co
|
||||
|
||||
#ifdef GGML_OPENCL_SOA_Q
|
||||
ggml_tensor_extra_cl_q4_0 * extra0_q4_0 = (ggml_tensor_extra_cl_q4_0 *)src0->extra;
|
||||
ggml_tensor_extra_cl_mxfp4 * extra0_mxfp4 = (ggml_tensor_extra_cl_mxfp4 *)src0->extra;
|
||||
#endif
|
||||
|
||||
const int ne00 = src0 ? src0->ne[0] : 0;
|
||||
@@ -6752,6 +6972,45 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co
|
||||
CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &r3));
|
||||
break;
|
||||
case GGML_TYPE_MXFP4: {
|
||||
#ifdef GGML_OPENCL_SOA_Q
|
||||
kernel = backend_ctx->kernel_mul_mv_mxfp4_f32_flat;
|
||||
|
||||
cl_mem q;
|
||||
if (backend_ctx->gpu_family == INTEL) {
|
||||
nth0 = 16;
|
||||
nth1 = 2;
|
||||
ndst = nth1*2;
|
||||
|
||||
q = extra0_mxfp4->q;
|
||||
} else if (backend_ctx->gpu_family == ADRENO) {
|
||||
nth0 = 64;
|
||||
nth1 = 2;
|
||||
ndst = nth1*2;
|
||||
|
||||
q = extra0_mxfp4->q_img;
|
||||
} else {
|
||||
GGML_ASSERT(false && "TODO: Unknown GPU");
|
||||
}
|
||||
|
||||
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &q));
|
||||
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra0_mxfp4->e));
|
||||
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device));
|
||||
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1));
|
||||
CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device));
|
||||
CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd));
|
||||
CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne00));
|
||||
CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_ulong), &nb01));
|
||||
CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &nb02));
|
||||
CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nb03));
|
||||
CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne12));
|
||||
CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb11));
|
||||
CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb12));
|
||||
CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &nb13));
|
||||
CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &ne0));
|
||||
CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &ne1));
|
||||
CL_CHECK(clSetKernelArg(kernel, 16, sizeof(int), &r2));
|
||||
CL_CHECK(clSetKernelArg(kernel, 17, sizeof(int), &r3));
|
||||
#else
|
||||
kernel = backend_ctx->kernel_mul_mv_mxfp4_f32;
|
||||
|
||||
if (backend_ctx->gpu_family == INTEL) {
|
||||
@@ -6785,6 +7044,7 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co
|
||||
CL_CHECK(clSetKernelArg(kernel, 16, sizeof(int), &r2));
|
||||
CL_CHECK(clSetKernelArg(kernel, 17, sizeof(int), &r3));
|
||||
CL_CHECK(clSetKernelArg(kernel, 18, sizeof(float)*nth0,nullptr));
|
||||
#endif
|
||||
break;
|
||||
}
|
||||
default:
|
||||
@@ -6850,8 +7110,11 @@ static void ggml_cl_mul_mat_id(ggml_backend_t backend, const ggml_tensor * src0,
|
||||
cl_ulong offset2 = extra2->offset + src2->view_offs;
|
||||
cl_ulong offsetd = extrad->offset + dst->view_offs;
|
||||
|
||||
GGML_UNUSED(offset0);
|
||||
|
||||
#ifdef GGML_OPENCL_SOA_Q
|
||||
ggml_tensor_extra_cl_q4_0 * extra0_q4_0 = (ggml_tensor_extra_cl_q4_0 *)src0->extra;
|
||||
ggml_tensor_extra_cl_mxfp4 * extra0_mxfp4 = (ggml_tensor_extra_cl_mxfp4 *)src0->extra;
|
||||
#endif
|
||||
|
||||
const int ne00 = src0->ne[0];
|
||||
@@ -6940,6 +7203,51 @@ static void ggml_cl_mul_mat_id(ggml_backend_t backend, const ggml_tensor * src0,
|
||||
break;
|
||||
}
|
||||
case GGML_TYPE_MXFP4: {
|
||||
#ifdef GGML_OPENCL_SOA_Q
|
||||
kernel = backend_ctx->kernel_mul_mv_id_mxfp4_f32_flat;
|
||||
|
||||
cl_mem q;
|
||||
if (backend_ctx->gpu_family == INTEL) {
|
||||
sgs = 16;
|
||||
nsg = 2;
|
||||
ndst = 2;
|
||||
|
||||
q = extra0_mxfp4->q;
|
||||
} else if (backend_ctx->gpu_family == ADRENO) {
|
||||
sgs = 64;
|
||||
nsg = 1;
|
||||
ndst = 4;
|
||||
|
||||
q = extra0_mxfp4->q_img;
|
||||
} else {
|
||||
GGML_ASSERT(false && "TODO: Unknown GPU");
|
||||
}
|
||||
|
||||
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &q));
|
||||
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra0_mxfp4->e));
|
||||
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device));
|
||||
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1));
|
||||
CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extra2->data_device));
|
||||
CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offset2));
|
||||
CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_mem), &extrad->data_device));
|
||||
CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_ulong), &offsetd));
|
||||
CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne00));
|
||||
CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nb01));
|
||||
CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb02));
|
||||
CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb03));
|
||||
CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &ne11));
|
||||
CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &ne12));
|
||||
CL_CHECK(clSetKernelArg(kernel, 14, sizeof(cl_ulong), &nb11));
|
||||
CL_CHECK(clSetKernelArg(kernel, 15, sizeof(cl_ulong), &nb12));
|
||||
CL_CHECK(clSetKernelArg(kernel, 16, sizeof(cl_ulong), &nb13));
|
||||
CL_CHECK(clSetKernelArg(kernel, 17, sizeof(int), &ne20));
|
||||
CL_CHECK(clSetKernelArg(kernel, 18, sizeof(int), &ne21));
|
||||
CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_ulong), &nb21));
|
||||
CL_CHECK(clSetKernelArg(kernel, 20, sizeof(int), &ne0));
|
||||
CL_CHECK(clSetKernelArg(kernel, 21, sizeof(int), &ne1));
|
||||
CL_CHECK(clSetKernelArg(kernel, 22, sizeof(int), &r2));
|
||||
CL_CHECK(clSetKernelArg(kernel, 23, sizeof(int), &r3));
|
||||
#else // GGML_OPENCL_SOA_Q
|
||||
kernel = backend_ctx->kernel_mul_mv_id_mxfp4_f32;
|
||||
|
||||
if (backend_ctx->gpu_family == INTEL) {
|
||||
@@ -6979,7 +7287,7 @@ static void ggml_cl_mul_mat_id(ggml_backend_t backend, const ggml_tensor * src0,
|
||||
CL_CHECK(clSetKernelArg(kernel, 22, sizeof(int), &r2));
|
||||
CL_CHECK(clSetKernelArg(kernel, 23, sizeof(int), &r3));
|
||||
CL_CHECK(clSetKernelArg(kernel, 24, sizeof(float)*sgs,nullptr));
|
||||
|
||||
#endif // GGML_OPENCL_SOA_Q
|
||||
break;
|
||||
}
|
||||
default:
|
||||
|
||||
@@ -116,3 +116,49 @@ kernel void kernel_convert_block_q4_0_noshuffle(
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
//------------------------------------------------------------------------------
|
||||
// block_mxfp4
|
||||
//------------------------------------------------------------------------------
|
||||
#define QK_MXFP4 32
|
||||
// One MXFP4 block in its original interleaved (array-of-structures) form:
// a single byte `e` followed by QK_MXFP4 / 2 bytes of packed quants.
// The struct is 1 + QK_MXFP4 / 2 = 17 bytes when no padding is inserted.
struct block_mxfp4 {
|
||||
uchar e; // E8M0
|
||||
uchar qs[QK_MXFP4 / 2]; // presumably two 4-bit values per byte (QK_MXFP4 values in QK_MXFP4 / 2 bytes) — TODO confirm
|
||||
};
|
||||
|
||||
//------------------------------------------------------------------------------
|
||||
// kernel_convert_block_mxfp4
|
||||
// Convert the block_mxfp4 format to 2 separate arrays (AOS -> SOA).
|
||||
// This kernel does not deshuffle the bits.
|
||||
//------------------------------------------------------------------------------
|
||||
kernel void kernel_convert_block_mxfp4(
|
||||
global struct block_mxfp4 * src0,
|
||||
global uchar * dst_q,
|
||||
global uchar * dst_e
|
||||
) {
|
||||
global struct block_mxfp4 * b = (global struct block_mxfp4 *) src0 + get_global_id(0);
|
||||
global uchar * q = (global uchar *) dst_q + QK_MXFP4 / 2 * get_global_id(0);
|
||||
global uchar * e = (global uchar *) dst_e + get_global_id(0);
|
||||
|
||||
*e = b->e;
|
||||
|
||||
for (int i = 0; i < QK_MXFP4 / 2; ++i) {
|
||||
q[i] = b->qs[i];
|
||||
}
|
||||
}
|
||||
|
||||
kernel void kernel_restore_block_mxfp4(
|
||||
global uchar * src_q,
|
||||
global half * src_e,
|
||||
global struct block_mxfp4 * dst
|
||||
) {
|
||||
global struct block_mxfp4 * b = (global struct block_mxfp4 *) dst + get_global_id(0);
|
||||
global uchar * q = (global uchar *) src_q + QK_MXFP4 / 2 * get_global_id(0);
|
||||
global uchar * e = (global uchar *) src_e + get_global_id(0);
|
||||
|
||||
b->e = *e;
|
||||
for (int i = 0; i < QK_MXFP4 / 2; ++i) {
|
||||
b->qs[i] = q[i];
|
||||
}
|
||||
}
|
||||
|
||||
176
ggml/src/ggml-opencl/kernels/mul_mv_id_mxfp4_f32_flat.cl
Normal file
176
ggml/src/ggml-opencl/kernels/mul_mv_id_mxfp4_f32_flat.cl
Normal file
@@ -0,0 +1,176 @@
|
||||
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
|
||||
|
||||
#ifdef cl_intel_subgroups
|
||||
#pragma OPENCL EXTENSION cl_intel_subgroups : enable
|
||||
#else
|
||||
#pragma OPENCL EXTENSION cl_khr_subgroups : enable
|
||||
#endif
|
||||
|
||||
#ifdef cl_intel_required_subgroup_size
|
||||
#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable
|
||||
#define INTEL_GPU 1
|
||||
#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16)))
|
||||
#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32)))
|
||||
#elif defined(cl_qcom_reqd_sub_group_size)
|
||||
#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
|
||||
#define ADRENO_GPU 1
|
||||
#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half")))
|
||||
#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full")))
|
||||
#endif
|
||||
|
||||
#define QK_MXFP4 32
|
||||
|
||||
static inline half4 mxfp4_to_fp16_packed(ushort fp4x4) {
|
||||
ushort2 fp16_packed_a, fp16_packed_b, bias_a, bias_b, sign_a, sign_b;
|
||||
fp16_packed_a.lo = (fp4x4 << 9) & 0x0E00;
|
||||
fp16_packed_a.hi = (fp4x4 << 5) & 0x0E00;
|
||||
fp16_packed_b.lo = (fp4x4 << 1) & 0x0E00;
|
||||
fp16_packed_b.hi = (fp4x4 >> 3) & 0x0E00;
|
||||
|
||||
bias_a.lo = (fp16_packed_a.lo == 0) ? 0x0 : 0x3800;
|
||||
bias_a.hi = (fp16_packed_a.hi == 0) ? 0x0 : 0x3800;
|
||||
bias_b.lo = (fp16_packed_b.lo == 0) ? 0x0 : 0x3800;
|
||||
bias_b.hi = (fp16_packed_b.hi == 0) ? 0x0 : 0x3800;
|
||||
|
||||
fp16_packed_a.lo = (fp16_packed_a.lo == 0x0200) ? 0x0 : fp16_packed_a.lo;
|
||||
fp16_packed_a.hi = (fp16_packed_a.hi == 0x0200) ? 0x0 : fp16_packed_a.hi;
|
||||
fp16_packed_b.lo = (fp16_packed_b.lo == 0x0200) ? 0x0 : fp16_packed_b.lo;
|
||||
fp16_packed_b.hi = (fp16_packed_b.hi == 0x0200) ? 0x0 : fp16_packed_b.hi;
|
||||
|
||||
sign_a.lo = (fp4x4 << 12) & 0x8000;
|
||||
sign_a.hi = (fp4x4 << 8) & 0x8000;
|
||||
sign_b.lo = (fp4x4 << 4) & 0x8000;
|
||||
sign_b.hi = fp4x4 & 0x8000;
|
||||
|
||||
fp16_packed_a = sign_a + bias_a + fp16_packed_a;
|
||||
fp16_packed_b = sign_b + bias_b + fp16_packed_b;
|
||||
|
||||
return as_half4((ushort4)(fp16_packed_a, fp16_packed_b));
|
||||
}
|
||||
|
||||
// Reinterprets an 8-bit value as a float by placing it at bit 23 and above
// of a 32-bit word. The input 0 is special-cased to the bit pattern
// 0x00400000 instead of all-zero bits.
static inline float e8m0_to_fp32(uchar x) {
    if (x == 0) {
        return as_float(0x00400000);
    }
    return as_float((uint) x << 23);
}
|
||||
|
||||
#ifdef INTEL_GPU
|
||||
#define N_R0_MXFP4 2 // number of rows each subgroup works on
|
||||
#define N_SG_MXFP4 2 // number of subgroups in a work group
|
||||
#define N_SIMDWIDTH 16 // subgroup size
|
||||
#elif defined (ADRENO_GPU)
|
||||
#define N_R0_MXFP4 4
|
||||
#define N_SG_MXFP4 1
|
||||
#define N_SIMDWIDTH 64
|
||||
#define SRC0Q_IMG
|
||||
#endif
|
||||
|
||||
kernel void kernel_mul_mv_id_mxfp4_f32_flat(
|
||||
#ifdef SRC0Q_IMG
|
||||
__read_only image1d_buffer_t src0_q,
|
||||
#else
|
||||
global uchar * src0_q,
|
||||
#endif
|
||||
global uchar * src0_e,
|
||||
global uchar * src1,
|
||||
ulong offset1,
|
||||
global uchar * src2,
|
||||
ulong offset2,
|
||||
global uchar * dst,
|
||||
ulong offsetd,
|
||||
int ne00,
|
||||
ulong nb01,
|
||||
ulong nb02,
|
||||
ulong nb03,
|
||||
int ne11,
|
||||
int ne12,
|
||||
ulong nb11,
|
||||
ulong nb12,
|
||||
ulong nb13,
|
||||
int ne20,
|
||||
int ne21,
|
||||
ulong nb21,
|
||||
int ne0,
|
||||
int ne1,
|
||||
int r2,
|
||||
int r3
|
||||
) {
|
||||
dst = dst + offsetd;
|
||||
|
||||
const int iid1 = get_group_id(2) / ne20;
|
||||
const int idx = get_group_id(2) % ne20;
|
||||
|
||||
uint i02 = ((global uint *) (src2 + offset2 + iid1 * nb21))[idx];
|
||||
|
||||
int i11 = idx % ne11;
|
||||
|
||||
int nb = ne00 / QK_MXFP4;
|
||||
|
||||
uint src0_off = i02*nb02;
|
||||
src0_off /= 17; // 17 = sizeof(block_mxfp4)
|
||||
|
||||
src0_e = src0_e + src0_off;
|
||||
|
||||
dst = dst + (idx * ne0 + iid1 * ne1 * ne0) * sizeof(float);
|
||||
|
||||
int r0 = get_group_id(0);
|
||||
int r1 = get_group_id(1);
|
||||
|
||||
int first_row = (r0 * N_SG_MXFP4 + get_sub_group_id()) * N_R0_MXFP4;
|
||||
|
||||
uint offset_src0 = first_row*nb01;
|
||||
offset_src0 /= 17; // 17 = sizeof(block_mxfp4)
|
||||
#ifdef SRC0Q_IMG
|
||||
ulong offset_q = src0_off + offset_src0;
|
||||
#else
|
||||
src0_q = src0_q + src0_off*16;
|
||||
global uchar16 * x_q = (global uchar16 *)(src0_q) + offset_src0;
|
||||
#endif
|
||||
global uchar * x_e = src0_e + offset_src0;
|
||||
|
||||
const short ix = get_sub_group_local_id() >> 1;
|
||||
const short it = get_sub_group_local_id() & 1;
|
||||
|
||||
float sumf[N_R0_MXFP4] = {0.f};
|
||||
|
||||
src1 = src1 + offset1 + i11 * nb11 + iid1 * nb12;
|
||||
global float * y = (global float *) (src1 + r1 * nb11);
|
||||
global float * yb = y + ix * QK_MXFP4 + it * 8;
|
||||
|
||||
for (int ib = ix; ib < nb; ib += N_SIMDWIDTH / 2) {
|
||||
global float4 * y4 = (global float4 *)yb;
|
||||
|
||||
#pragma unroll
|
||||
for (short row = 0; row < N_R0_MXFP4; row++) {
|
||||
uchar xb_e = x_e[row * nb + ib];
|
||||
#ifdef SRC0Q_IMG
|
||||
ushort4 xb_q = as_ushort4(read_imageui(src0_q, (offset_q + row * nb + ib) * 2 + it).xy);
|
||||
#else
|
||||
ushort4 xb_q = vload4(0, (global ushort *)((global uchar *)(x_q + row * nb + ib) + 8 * it));
|
||||
#endif
|
||||
|
||||
half4 fp16x4_0 = mxfp4_to_fp16_packed(xb_q.s0);
|
||||
half4 fp16x4_1 = mxfp4_to_fp16_packed(xb_q.s1);
|
||||
float4 acc1 = y4[0] * (float4)(fp16x4_0.s0, fp16x4_0.s2, fp16x4_1.s0, fp16x4_1.s2);
|
||||
acc1 += y4[4] * (float4)(fp16x4_0.s1, fp16x4_0.s3, fp16x4_1.s1, fp16x4_1.s3);
|
||||
|
||||
fp16x4_0 = mxfp4_to_fp16_packed(xb_q.s2);
|
||||
fp16x4_1 = mxfp4_to_fp16_packed(xb_q.s3);
|
||||
acc1 += y4[1] * (float4)(fp16x4_0.s0, fp16x4_0.s2, fp16x4_1.s0, fp16x4_1.s2);
|
||||
acc1 += y4[5] * (float4)(fp16x4_0.s1, fp16x4_0.s3, fp16x4_1.s1, fp16x4_1.s3);
|
||||
|
||||
sumf[row] += e8m0_to_fp32(xb_e) * ((acc1.s0 + acc1.s1) + (acc1.s2 + acc1.s3));
|
||||
}
|
||||
|
||||
yb += (N_SIMDWIDTH / 2) * QK_MXFP4;
|
||||
}
|
||||
|
||||
global float * dst_f32 = (global float *)dst + (ulong)r1 * ne0;
|
||||
|
||||
for (int row = 0; row < N_R0_MXFP4 && first_row + row < ne0; ++row) {
|
||||
float sum_all = sub_group_reduce_add(sumf[row]);
|
||||
if (get_sub_group_local_id() == 0) {
|
||||
dst_f32[first_row + row] = sum_all;
|
||||
}
|
||||
}
|
||||
}
|
||||
167
ggml/src/ggml-opencl/kernels/mul_mv_mxfp4_f32_flat.cl
Normal file
167
ggml/src/ggml-opencl/kernels/mul_mv_mxfp4_f32_flat.cl
Normal file
@@ -0,0 +1,167 @@
|
||||
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
|
||||
|
||||
#ifdef cl_intel_subgroups
|
||||
#pragma OPENCL EXTENSION cl_intel_subgroups : enable
|
||||
#else
|
||||
#pragma OPENCL EXTENSION cl_khr_subgroups : enable
|
||||
#endif
|
||||
|
||||
#ifdef cl_intel_required_subgroup_size
|
||||
#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable
|
||||
#define INTEL_GPU 1
|
||||
#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16)))
|
||||
#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32)))
|
||||
#elif defined(cl_qcom_reqd_sub_group_size)
|
||||
#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
|
||||
#define ADRENO_GPU 1
|
||||
#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half")))
|
||||
#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full")))
|
||||
#endif
|
||||
|
||||
#define QK_MXFP4 32
|
||||
|
||||
static inline half4 mxfp4_to_fp16_packed(ushort fp4x4) {
|
||||
ushort2 fp16_packed_a, fp16_packed_b, bias_a, bias_b, sign_a, sign_b;
|
||||
fp16_packed_a.lo = (fp4x4 << 9) & 0x0E00;
|
||||
fp16_packed_a.hi = (fp4x4 << 5) & 0x0E00;
|
||||
fp16_packed_b.lo = (fp4x4 << 1) & 0x0E00;
|
||||
fp16_packed_b.hi = (fp4x4 >> 3) & 0x0E00;
|
||||
|
||||
bias_a.lo = (fp16_packed_a.lo == 0) ? 0x0 : 0x3800;
|
||||
bias_a.hi = (fp16_packed_a.hi == 0) ? 0x0 : 0x3800;
|
||||
bias_b.lo = (fp16_packed_b.lo == 0) ? 0x0 : 0x3800;
|
||||
bias_b.hi = (fp16_packed_b.hi == 0) ? 0x0 : 0x3800;
|
||||
|
||||
fp16_packed_a.lo = (fp16_packed_a.lo == 0x0200) ? 0x0 : fp16_packed_a.lo;
|
||||
fp16_packed_a.hi = (fp16_packed_a.hi == 0x0200) ? 0x0 : fp16_packed_a.hi;
|
||||
fp16_packed_b.lo = (fp16_packed_b.lo == 0x0200) ? 0x0 : fp16_packed_b.lo;
|
||||
fp16_packed_b.hi = (fp16_packed_b.hi == 0x0200) ? 0x0 : fp16_packed_b.hi;
|
||||
|
||||
sign_a.lo = (fp4x4 << 12) & 0x8000;
|
||||
sign_a.hi = (fp4x4 << 8) & 0x8000;
|
||||
sign_b.lo = (fp4x4 << 4) & 0x8000;
|
||||
sign_b.hi = fp4x4 & 0x8000;
|
||||
|
||||
fp16_packed_a = sign_a + bias_a + fp16_packed_a;
|
||||
fp16_packed_b = sign_b + bias_b + fp16_packed_b;
|
||||
|
||||
return as_half4((ushort4)(fp16_packed_a, fp16_packed_b));
|
||||
}
|
||||
|
||||
// Reinterprets an 8-bit value as a float by placing it at bit 23 and above
// of a 32-bit word. The input 0 is special-cased to the bit pattern
// 0x00400000 instead of all-zero bits.
static inline float e8m0_to_fp32(uchar x) {
    if (x == 0) {
        return as_float(0x00400000);
    }
    return as_float((uint) x << 23);
}
|
||||
|
||||
#ifdef INTEL_GPU
|
||||
#define N_R0_MXFP4 2 // number of rows each subgroup works on
|
||||
#define N_SG_MXFP4 2 // number of subgroups in a work group
|
||||
#define N_SIMDWIDTH 16 // subgroup size
|
||||
#elif defined (ADRENO_GPU)
|
||||
#define N_R0_MXFP4 2
|
||||
#define N_SG_MXFP4 2
|
||||
#define N_SIMDWIDTH 64
|
||||
#define SRC0Q_IMG
|
||||
#endif
|
||||
|
||||
#ifdef INTEL_GPU
|
||||
REQD_SUBGROUP_SIZE_16
|
||||
#elif defined (ADRENO_GPU)
|
||||
REQD_SUBGROUP_SIZE_64
|
||||
#endif
|
||||
kernel void kernel_mul_mv_mxfp4_f32_flat(
|
||||
#ifdef SRC0Q_IMG
|
||||
__read_only image1d_buffer_t src0_q,
|
||||
#else
|
||||
global uchar * src0_q,
|
||||
#endif
|
||||
global uchar * src0_e,
|
||||
global uchar * src1,
|
||||
ulong offset1,
|
||||
global uchar * dst,
|
||||
ulong offsetd,
|
||||
int ne00,
|
||||
ulong nb01,
|
||||
ulong nb02,
|
||||
ulong nb03,
|
||||
int ne12,
|
||||
ulong nb11,
|
||||
ulong nb12,
|
||||
ulong nb13,
|
||||
int ne0,
|
||||
int ne1,
|
||||
int r2,
|
||||
int r3
|
||||
) {
|
||||
src1 = src1 + offset1;
|
||||
dst = dst + offsetd;
|
||||
|
||||
int nb = ne00 / QK_MXFP4;
|
||||
|
||||
int r0 = get_group_id(0);
|
||||
int r1 = get_group_id(1);
|
||||
int im = get_group_id(2);
|
||||
|
||||
int first_row = (r0 * N_SG_MXFP4 + get_sub_group_id()) * N_R0_MXFP4;
|
||||
|
||||
uint i12 = im % ne12;
|
||||
uint i13 = im / ne12;
|
||||
|
||||
uint offset_src0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
|
||||
// 17 = sizeof(block_mxfp4)
|
||||
offset_src0 /= 17;
|
||||
#ifdef SRC0Q_IMG
|
||||
ulong offset_q = offset_src0;
|
||||
#else
|
||||
global uchar16 * x_q = (global uchar16 *)(src0_q) + offset_src0;
|
||||
#endif
|
||||
global uchar * x_e = src0_e + offset_src0;
|
||||
|
||||
ulong offset_src1 = r1 * nb11 + i12 * nb12 + i13 * nb13;
|
||||
global float * y = (global float *)(src1 + offset_src1);
|
||||
|
||||
const short ix = get_sub_group_local_id() >> 1; // 0...15
|
||||
const short it = get_sub_group_local_id() & 1; // 0 or 1
|
||||
|
||||
float sumf[N_R0_MXFP4] = {0.f};
|
||||
|
||||
global float * yb = y + ix * QK_MXFP4 + it * 8;
|
||||
|
||||
for (int ib = ix; ib < nb; ib += N_SIMDWIDTH/2) {
|
||||
global float4 * y4 = (global float4 *)yb;
|
||||
|
||||
#pragma unroll
|
||||
for (short row = 0; row < N_R0_MXFP4; row++) {
|
||||
uchar xb_e = x_e[row * nb + ib];
|
||||
#ifdef SRC0Q_IMG
|
||||
ushort4 xb_q = as_ushort4(read_imageui(src0_q, (offset_q + row * nb + ib) * 2 + it).xy);
|
||||
#else
|
||||
ushort4 xb_q = vload4(0, (global ushort *)((global uchar *)(x_q + row * nb + ib) + 8 * it));
|
||||
#endif
|
||||
|
||||
half4 fp16x4_0 = mxfp4_to_fp16_packed(xb_q.s0);
|
||||
half4 fp16x4_1 = mxfp4_to_fp16_packed(xb_q.s1);
|
||||
float4 acc1 = y4[0] * (float4)(fp16x4_0.s0, fp16x4_0.s2, fp16x4_1.s0, fp16x4_1.s2);
|
||||
acc1 += y4[4] * (float4)(fp16x4_0.s1, fp16x4_0.s3, fp16x4_1.s1, fp16x4_1.s3);
|
||||
|
||||
fp16x4_0 = mxfp4_to_fp16_packed(xb_q.s2);
|
||||
fp16x4_1 = mxfp4_to_fp16_packed(xb_q.s3);
|
||||
acc1 += y4[1] * (float4)(fp16x4_0.s0, fp16x4_0.s2, fp16x4_1.s0, fp16x4_1.s2);
|
||||
acc1 += y4[5] * (float4)(fp16x4_0.s1, fp16x4_0.s3, fp16x4_1.s1, fp16x4_1.s3);
|
||||
|
||||
sumf[row] += e8m0_to_fp32(xb_e) * ((acc1.s0 + acc1.s1) + (acc1.s2 + acc1.s3));
|
||||
}
|
||||
|
||||
yb += (N_SIMDWIDTH/2) * QK_MXFP4;
|
||||
}
|
||||
|
||||
global float * dst_f32 = (global float *) dst + (ulong)im*ne0*ne1 + (ulong)r1*ne0;
|
||||
|
||||
for (int row = 0; row < N_R0_MXFP4 && first_row + row < ne0; ++row) {
|
||||
float sum_all = sub_group_reduce_add(sumf[row]);
|
||||
if (get_sub_group_local_id() == 0) {
|
||||
dst_f32[first_row + row] = sum_all;
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -795,7 +795,7 @@ static ggml_backend_i ggml_backend_rpc_interface = {
|
||||
/* .graph_compute = */ ggml_backend_rpc_graph_compute,
|
||||
/* .event_record = */ NULL,
|
||||
/* .event_wait = */ NULL,
|
||||
/* .optimize_graph = */ NULL,
|
||||
/* .graph_optimize = */ NULL,
|
||||
};
|
||||
|
||||
ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type(const char * endpoint) {
|
||||
|
||||
@@ -4073,7 +4073,7 @@ static ggml_backend_i ggml_backend_sycl_interface = {
|
||||
/* .graph_compute = */ ggml_backend_sycl_graph_compute,
|
||||
/* .event_record = */ ggml_backend_sycl_event_record,
|
||||
/* .event_wait = */ ggml_backend_sycl_event_wait,
|
||||
/* .optimize_graph = */ NULL,
|
||||
/* .graph_optimize = */ NULL,
|
||||
};
|
||||
|
||||
static ggml_guid_t ggml_backend_sycl_guid() {
|
||||
|
||||
@@ -593,7 +593,7 @@ struct vk_device_struct {
|
||||
bool disable_fusion;
|
||||
bool disable_host_visible_vidmem;
|
||||
bool allow_sysmem_fallback;
|
||||
bool disable_optimize_graph;
|
||||
bool disable_graph_optimize;
|
||||
|
||||
#ifdef GGML_VULKAN_MEMORY_DEBUG
|
||||
std::unique_ptr<vk_memory_logger> memory_logger;
|
||||
@@ -3624,8 +3624,8 @@ static vk_device ggml_vk_get_device(size_t idx) {
|
||||
const char* GGML_VK_ALLOW_SYSMEM_FALLBACK = getenv("GGML_VK_ALLOW_SYSMEM_FALLBACK");
|
||||
device->allow_sysmem_fallback = GGML_VK_ALLOW_SYSMEM_FALLBACK != nullptr;
|
||||
|
||||
const char* GGML_VK_DISABLE_OPTIMIZE_GRAPH = getenv("GGML_VK_DISABLE_OPTIMIZE_GRAPH");
|
||||
device->disable_optimize_graph = GGML_VK_DISABLE_OPTIMIZE_GRAPH != nullptr;
|
||||
const char* GGML_VK_DISABLE_GRAPH_OPTIMIZE = getenv("GGML_VK_DISABLE_GRAPH_OPTIMIZE");
|
||||
device->disable_graph_optimize = GGML_VK_DISABLE_GRAPH_OPTIMIZE != nullptr;
|
||||
|
||||
bool fp16_storage = false;
|
||||
bool fp16_compute = false;
|
||||
@@ -11914,12 +11914,12 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
|
||||
}
|
||||
|
||||
// Sort the graph for improved parallelism.
|
||||
static void ggml_vk_optimize_graph(ggml_backend_t backend, struct ggml_cgraph * graph)
|
||||
static void ggml_vk_graph_optimize(ggml_backend_t backend, struct ggml_cgraph * graph)
|
||||
{
|
||||
VK_LOG_DEBUG("ggml_vk_optimize_graph(" << graph->n_nodes << " nodes)");
|
||||
VK_LOG_DEBUG("ggml_vk_graph_optimize(" << graph->n_nodes << " nodes)");
|
||||
ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context;
|
||||
|
||||
if (ctx->device->disable_optimize_graph) {
|
||||
if (ctx->device->disable_graph_optimize) {
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -12053,7 +12053,7 @@ static ggml_backend_i ggml_backend_vk_interface = {
|
||||
/* .graph_compute = */ ggml_backend_vk_graph_compute,
|
||||
/* .event_record = */ NULL,
|
||||
/* .event_wait = */ NULL,
|
||||
/* .optimize_graph = */ ggml_vk_optimize_graph,
|
||||
/* .graph_optimize = */ ggml_vk_graph_optimize,
|
||||
};
|
||||
|
||||
static ggml_guid_t ggml_backend_vk_guid() {
|
||||
|
||||
@@ -823,7 +823,7 @@ static ggml_backend_i ggml_backend_webgpu_i = {
|
||||
/* .graph_compute = */ ggml_backend_webgpu_graph_compute,
|
||||
/* .event_record = */ NULL,
|
||||
/* .event_wait = */ NULL,
|
||||
/* .optimize_graph = */ NULL,
|
||||
/* .graph_optimize = */ NULL,
|
||||
};
|
||||
|
||||
/* End GGML Backend Interface */
|
||||
|
||||
@@ -574,7 +574,7 @@ static ggml_backend_i ggml_backend_zdnn_i = {
|
||||
/* .graph_compute = */ ggml_backend_zdnn_graph_compute,
|
||||
/* .event_record = */ NULL,
|
||||
/* .event_wait = */ NULL,
|
||||
/* .optimize_graph = */ NULL,
|
||||
/* .graph_optimize = */ NULL,
|
||||
};
|
||||
|
||||
static ggml_guid_t ggml_backend_zdnn_guid(void) {
|
||||
|
||||
@@ -6507,6 +6507,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
|
||||
test_cases.emplace_back(new test_pad());
|
||||
test_cases.emplace_back(new test_pad_ext());
|
||||
test_cases.emplace_back(new test_pad_reflect_1d());
|
||||
test_cases.emplace_back(new test_pad_reflect_1d(GGML_TYPE_F32, {3000, 384, 4, 1}));
|
||||
test_cases.emplace_back(new test_roll());
|
||||
test_cases.emplace_back(new test_arange());
|
||||
test_cases.emplace_back(new test_timestep_embedding());
|
||||
@@ -6645,6 +6646,12 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
|
||||
test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {1024, 10, 1, 1}));
|
||||
test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {32000, 512, 1, 1}));
|
||||
|
||||
test_cases.emplace_back(new test_pad_reflect_1d(GGML_TYPE_F32, {512, 34, 2, 1}));
|
||||
test_cases.emplace_back(new test_pad_reflect_1d(GGML_TYPE_F32, {3000, 80, 1, 1}));
|
||||
test_cases.emplace_back(new test_pad_reflect_1d(GGML_TYPE_F32, {3000, 80, 4, 1}));
|
||||
test_cases.emplace_back(new test_pad_reflect_1d(GGML_TYPE_F32, {3000, 384, 1, 1}));
|
||||
test_cases.emplace_back(new test_pad_reflect_1d(GGML_TYPE_F32, {3000, 384, 4, 1}));
|
||||
|
||||
test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 16416, 1, 128, {8, 1}, {4, 1}, {0, 2, 1, 3}));
|
||||
test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 128, 1, 16416, {8, 1}, {4, 1}, {0, 1, 2, 3}, true));
|
||||
|
||||
|
||||
Reference in New Issue
Block a user