Compare commits

...

10 Commits
b7678 ... b7688

Author SHA1 Message Date
shaofeiqi
593da7fa49 opencl: add EXPM1 op (#18704) 2026-01-09 10:13:13 -08:00
Reese Levine
9e41884dce Updates to webgpu get_memory (#18707) 2026-01-09 08:17:18 -08:00
Pascal
ec8fd7876b Webui/file upload (#18694)
* webui: fix restrictive file type validation

* webui: simplify file processing logic

* chore: update webui build output

* webui: remove file picker extension whitelist (1/2)

* webui: remove file picker extension whitelist (2/2)

* chore: update webui build output

* refactor: Cleanup

* chore: update webui build output

* fix: update ChatForm storybook test after removing accept attribute

* chore: update webui build output

* refactor: more cleanup

* chore: update webui build output
2026-01-09 16:45:32 +01:00
Asbjørn Olling
a180ba78c7 cmake: only build cli when server is enabled (#18670) 2026-01-09 16:43:26 +01:00
Georgi Gerganov
53eb9435da server : fix timing of prompt/generation (#18713) 2026-01-09 12:59:50 +02:00
Georgi Gerganov
d3435efc8a scripts : pr2wt.sh reset to remote head (#18695)
* scripts : pr2wt.sh reset to remote head

* cont : cleaner

* cont : restore --set-upstream-to
2026-01-09 12:16:40 +02:00
Georgi Gerganov
f5f8812f7c server : use different seeds for child completions (#18700)
* server : use different seeds for child completions

* cont : handle default seed

* cont : note
2026-01-09 09:33:50 +02:00
Xuan-Son Nguyen
8ece3836b4 common: support remote preset (#18520)
* arg: support remote preset

* proof reading

* allow one HF repo to point to multiple HF repos

* docs: mention about multiple GGUF use case

* correct clean_file_name

* download: also return HTTP status code

* fix case with cache file used

* fix --offline option
2026-01-08 22:35:40 +01:00
Aaron Teo
046d5fd44e llama: use host memory if device reports 0 memory (#18587) 2026-01-09 05:34:56 +08:00
Masashi Yoshimura
480160d472 ggml-webgpu: Fix GGML_MEM_ALIGN to 8 for emscripten. (#18628)
* Fix GGML_MEM_ALIGN to 8 for emscripten.

* Add a comment explaining the need for GGML_MEM_ALIGN == 8 in 64-bit wasm with emscripten
2026-01-08 08:36:42 -08:00
30 changed files with 672 additions and 292 deletions

View File

@@ -6,6 +6,7 @@
#include "log.h"
#include "sampling.h"
#include "download.h"
#include "preset.h"
// fix problem with std::min and std::max
#if defined(_WIN32)
@@ -268,6 +269,46 @@ static void parse_tensor_buffer_overrides(const std::string & value, std::vector
}
}
static std::string clean_file_name(const std::string & fname) {
std::string clean_fname = fname;
string_replace_all(clean_fname, "\\", "_");
string_replace_all(clean_fname, "/", "_");
return clean_fname;
}
static bool common_params_handle_remote_preset(common_params & params, llama_example ex) {
GGML_ASSERT(!params.model.hf_repo.empty());
const bool offline = params.offline;
std::string model_endpoint = get_model_endpoint();
auto preset_url = model_endpoint + params.model.hf_repo + "/resolve/main/preset.ini";
// prepare local path for caching
auto preset_fname = clean_file_name(params.model.hf_repo + "_preset.ini");
auto preset_path = fs_get_cache_file(preset_fname);
const int status = common_download_file_single(preset_url, preset_path, params.hf_token, offline);
const bool has_preset = status >= 200 && status < 400;
// remote preset is optional, so we don't error out if not found
if (has_preset) {
LOG_INF("applying remote preset from %s\n", preset_url.c_str());
common_preset_context ctx(ex, /* only_remote_allowed */ true);
common_preset global; // unused for now
auto remote_presets = ctx.load_from_ini(preset_path, global);
if (remote_presets.find(COMMON_PRESET_DEFAULT_NAME) != remote_presets.end()) {
common_preset & preset = remote_presets.at(COMMON_PRESET_DEFAULT_NAME);
LOG_INF("\n%s", preset.to_ini().c_str()); // to_ini already added trailing newline
preset.apply_to_params(params);
} else {
throw std::runtime_error("Remote preset.ini does not contain [" + std::string(COMMON_PRESET_DEFAULT_NAME) + "] section");
}
} else {
LOG_INF("%s", "no remote preset found, skipping\n");
}
return has_preset;
}
struct handle_model_result {
bool found_mmproj = false;
common_params_model mmproj;
@@ -309,9 +350,7 @@ static handle_model_result common_params_handle_model(
// make sure model path is present (for caching purposes)
if (model.path.empty()) {
// this is to avoid different repo having same file name, or same file name in different subdirs
std::string filename = model.hf_repo + "_" + model.hf_file;
// to make sure we don't have any slashes in the filename
string_replace_all(filename, "/", "_");
std::string filename = clean_file_name(model.hf_repo + "_" + model.hf_file);
model.path = fs_get_cache_file(filename);
}
@@ -425,61 +464,87 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_context
}
};
std::set<std::string> seen_args;
auto parse_cli_args = [&]() {
std::set<std::string> seen_args;
for (int i = 1; i < argc; i++) {
const std::string arg_prefix = "--";
for (int i = 1; i < argc; i++) {
const std::string arg_prefix = "--";
std::string arg = argv[i];
if (arg.compare(0, arg_prefix.size(), arg_prefix) == 0) {
std::replace(arg.begin(), arg.end(), '_', '-');
}
if (arg_to_options.find(arg) == arg_to_options.end()) {
throw std::invalid_argument(string_format("error: invalid argument: %s", arg.c_str()));
}
if (!seen_args.insert(arg).second) {
LOG_WRN("DEPRECATED: argument '%s' specified multiple times, use comma-separated values instead (only last value will be used)\n", arg.c_str());
}
auto & tmp = arg_to_options[arg];
auto opt = *tmp.first;
bool is_positive = tmp.second;
if (opt.has_value_from_env()) {
fprintf(stderr, "warn: %s environment variable is set, but will be overwritten by command line argument %s\n", opt.env, arg.c_str());
}
try {
if (opt.handler_void) {
opt.handler_void(params);
continue;
std::string arg = argv[i];
if (arg.compare(0, arg_prefix.size(), arg_prefix) == 0) {
std::replace(arg.begin(), arg.end(), '_', '-');
}
if (opt.handler_bool) {
opt.handler_bool(params, is_positive);
continue;
if (arg_to_options.find(arg) == arg_to_options.end()) {
throw std::invalid_argument(string_format("error: invalid argument: %s", arg.c_str()));
}
if (!seen_args.insert(arg).second) {
LOG_WRN("DEPRECATED: argument '%s' specified multiple times, use comma-separated values instead (only last value will be used)\n", arg.c_str());
}
auto & tmp = arg_to_options[arg];
auto opt = *tmp.first;
bool is_positive = tmp.second;
if (opt.has_value_from_env()) {
fprintf(stderr, "warn: %s environment variable is set, but will be overwritten by command line argument %s\n", opt.env, arg.c_str());
}
try {
if (opt.handler_void) {
opt.handler_void(params);
continue;
}
if (opt.handler_bool) {
opt.handler_bool(params, is_positive);
continue;
}
// arg with single value
check_arg(i);
std::string val = argv[++i];
if (opt.handler_int) {
opt.handler_int(params, std::stoi(val));
continue;
}
if (opt.handler_string) {
opt.handler_string(params, val);
continue;
}
// arg with single value
check_arg(i);
std::string val = argv[++i];
if (opt.handler_int) {
opt.handler_int(params, std::stoi(val));
continue;
}
if (opt.handler_string) {
opt.handler_string(params, val);
continue;
}
// arg with 2 values
check_arg(i);
std::string val2 = argv[++i];
if (opt.handler_str_str) {
opt.handler_str_str(params, val, val2);
continue;
// arg with 2 values
check_arg(i);
std::string val2 = argv[++i];
if (opt.handler_str_str) {
opt.handler_str_str(params, val, val2);
continue;
}
} catch (std::exception & e) {
throw std::invalid_argument(string_format(
"error while handling argument \"%s\": %s\n\n"
"usage:\n%s\n\nto show complete usage, run with -h",
arg.c_str(), e.what(), opt.to_string().c_str()));
}
} catch (std::exception & e) {
throw std::invalid_argument(string_format(
"error while handling argument \"%s\": %s\n\n"
"usage:\n%s\n\nto show complete usage, run with -h",
arg.c_str(), e.what(), opt.to_string().c_str()));
}
};
// parse the first time to get -hf option (used for remote preset)
parse_cli_args();
// maybe handle remote preset
if (!params.model.hf_repo.empty()) {
std::string cli_hf_repo = params.model.hf_repo;
bool has_preset = common_params_handle_remote_preset(params, ctx_arg.ex);
// special case: if hf_repo explicitly set by preset, we need to preserve it (ignore CLI value)
// this is useful when we have one HF repo pointing to other HF repos (one model - multiple GGUFs)
std::string preset_hf_repo = params.model.hf_repo;
bool preset_has_hf_repo = preset_hf_repo != cli_hf_repo;
if (has_preset) {
// re-parse CLI args to override preset values
parse_cli_args();
}
// preserve hf_repo from preset if needed
if (preset_has_hf_repo) {
params.model.hf_repo = preset_hf_repo;
}
}

View File

@@ -157,6 +157,10 @@ static std::string read_etag(const std::string & path) {
return none;
}
static bool is_http_status_ok(int status) {
return status >= 200 && status < 400;
}
#ifdef LLAMA_USE_CURL
//
@@ -306,12 +310,14 @@ static bool common_download_head(CURL * curl,
}
// download one single file from remote URL to local path
static bool common_download_file_single_online(const std::string & url,
// returns status code or -1 on error
static int common_download_file_single_online(const std::string & url,
const std::string & path,
const std::string & bearer_token,
const common_header_list & custom_headers) {
static const int max_attempts = 3;
static const int retry_delay_seconds = 2;
for (int i = 0; i < max_attempts; ++i) {
std::string etag;
@@ -371,7 +377,7 @@ static bool common_download_file_single_online(const std::string & url,
LOG_WRN("%s: deleting previous downloaded file: %s\n", __func__, path.c_str());
if (remove(path.c_str()) != 0) {
LOG_ERR("%s: unable to delete file: %s\n", __func__, path.c_str());
return false;
return -1;
}
}
@@ -380,14 +386,14 @@ static bool common_download_file_single_online(const std::string & url,
if (std::filesystem::exists(path_temporary)) {
if (remove(path_temporary.c_str()) != 0) {
LOG_ERR("%s: unable to delete file: %s\n", __func__, path_temporary.c_str());
return false;
return -1;
}
}
if (std::filesystem::exists(path)) {
if (remove(path.c_str()) != 0) {
LOG_ERR("%s: unable to delete file: %s\n", __func__, path.c_str());
return false;
return -1;
}
}
}
@@ -414,23 +420,27 @@ static bool common_download_file_single_online(const std::string & url,
long http_code = 0;
curl_easy_getinfo(curl.get(), CURLINFO_RESPONSE_CODE, &http_code);
if (http_code < 200 || http_code >= 400) {
int status = static_cast<int>(http_code);
if (!is_http_status_ok(http_code)) {
LOG_ERR("%s: invalid http status code received: %ld\n", __func__, http_code);
return false;
return status; // TODO: maybe only return on certain codes
}
if (rename(path_temporary.c_str(), path.c_str()) != 0) {
LOG_ERR("%s: unable to rename file: %s to %s\n", __func__, path_temporary.c_str(), path.c_str());
return false;
return -1;
}
return static_cast<int>(http_code);
} else {
LOG_INF("%s: using cached file: %s\n", __func__, path.c_str());
}
break;
return 304; // Not Modified - fake cached response
}
}
return true;
return -1; // max attempts reached
}
std::pair<long, std::vector<char>> common_remote_get_content(const std::string & url, const common_remote_params & params) {
@@ -625,7 +635,8 @@ static bool common_pull_file(httplib::Client & cli,
}
// download one single file from remote URL to local path
static bool common_download_file_single_online(const std::string & url,
// returns status code or -1 on error
static int common_download_file_single_online(const std::string & url,
const std::string & path,
const std::string & bearer_token,
const common_header_list & custom_headers) {
@@ -659,8 +670,10 @@ static bool common_download_file_single_online(const std::string & url,
LOG_WRN("%s: HEAD invalid http status code received: %d\n", __func__, head ? head->status : -1);
if (file_exists) {
LOG_INF("%s: Using cached file (HEAD failed): %s\n", __func__, path.c_str());
return true;
return 304; // 304 Not Modified - fake cached response
}
return head->status; // cannot use cached file, return raw status code
// TODO: maybe retry only on certain codes
}
std::string etag;
@@ -692,12 +705,12 @@ static bool common_download_file_single_online(const std::string & url,
if (file_exists) {
if (!should_download_from_scratch) {
LOG_INF("%s: using cached file: %s\n", __func__, path.c_str());
return true;
return 304; // 304 Not Modified - fake cached response
}
LOG_WRN("%s: deleting previous downloaded file: %s\n", __func__, path.c_str());
if (remove(path.c_str()) != 0) {
LOG_ERR("%s: unable to delete file: %s\n", __func__, path.c_str());
return false;
return -1;
}
}
@@ -709,7 +722,7 @@ static bool common_download_file_single_online(const std::string & url,
existing_size = std::filesystem::file_size(path_temporary);
} else if (remove(path_temporary.c_str()) != 0) {
LOG_ERR("%s: unable to delete file: %s\n", __func__, path_temporary.c_str());
return false;
return -1;
}
}
@@ -730,15 +743,16 @@ static bool common_download_file_single_online(const std::string & url,
if (std::rename(path_temporary.c_str(), path.c_str()) != 0) {
LOG_ERR("%s: unable to rename file: %s to %s\n", __func__, path_temporary.c_str(), path.c_str());
return false;
return -1;
}
if (!etag.empty()) {
write_etag(path, etag);
}
break;
return head->status; // TODO: use actual GET status?
}
return true;
return -1; // max attempts reached
}
std::pair<long, std::vector<char>> common_remote_get_content(const std::string & url,
@@ -777,22 +791,22 @@ std::pair<long, std::vector<char>> common_remote_get_content(const std::string
#if defined(LLAMA_USE_CURL) || defined(LLAMA_USE_HTTPLIB)
static bool common_download_file_single(const std::string & url,
const std::string & path,
const std::string & bearer_token,
bool offline,
const common_header_list & headers) {
int common_download_file_single(const std::string & url,
const std::string & path,
const std::string & bearer_token,
bool offline,
const common_header_list & headers) {
if (!offline) {
return common_download_file_single_online(url, path, bearer_token, headers);
}
if (!std::filesystem::exists(path)) {
LOG_ERR("%s: required file is not available in cache (offline mode): %s\n", __func__, path.c_str());
return false;
return -1;
}
LOG_INF("%s: using cached file (offline mode): %s\n", __func__, path.c_str());
return true;
return 304; // Not Modified - fake cached response
}
// download multiple files from remote URLs to local paths
@@ -810,7 +824,8 @@ static bool common_download_file_multiple(const std::vector<std::pair<std::strin
std::async(
std::launch::async,
[&bearer_token, offline, &headers](const std::pair<std::string, std::string> & it) -> bool {
return common_download_file_single(it.first, it.second, bearer_token, offline, headers);
const int http_status = common_download_file_single(it.first, it.second, bearer_token, offline, headers);
return is_http_status_ok(http_status);
},
item
)
@@ -837,7 +852,8 @@ bool common_download_model(const common_params_model & model,
return false;
}
if (!common_download_file_single(model.url, model.path, bearer_token, offline, headers)) {
const int http_status = common_download_file_single(model.url, model.path, bearer_token, offline, headers);
if (!is_http_status_ok(http_status)) {
return false;
}
@@ -975,7 +991,7 @@ common_hf_file_res common_get_hf_file(const std::string & hf_repo_with_tag,
} else if (res_code == 401) {
throw std::runtime_error("error: model is private or does not exist; if you are accessing a gated model, please provide a valid HF token");
} else {
throw std::runtime_error(string_format("error from HF API, response code: %ld, data: %s", res_code, res_str.c_str()));
throw std::runtime_error(string_format("error from HF API (%s), response code: %ld, data: %s", url.c_str(), res_code, res_str.c_str()));
}
// check response
@@ -1094,7 +1110,8 @@ std::string common_docker_resolve_model(const std::string & docker) {
std::string local_path = fs_get_cache_file(model_filename);
const std::string blob_url = url_prefix + "/blobs/" + gguf_digest;
if (!common_download_file_single(blob_url, local_path, token, false, {})) {
const int http_status = common_download_file_single(blob_url, local_path, token, false, {});
if (!is_http_status_ok(http_status)) {
throw std::runtime_error("Failed to download Docker Model");
}
@@ -1120,6 +1137,14 @@ std::string common_docker_resolve_model(const std::string &) {
throw std::runtime_error("download functionality is not enabled in this build");
}
int common_download_file_single(const std::string &,
const std::string &,
const std::string &,
bool,
const common_header_list &) {
throw std::runtime_error("download functionality is not enabled in this build");
}
#endif // LLAMA_USE_CURL || LLAMA_USE_HTTPLIB
std::vector<common_cached_model_info> common_list_cached_models() {

View File

@@ -65,6 +65,14 @@ bool common_download_model(
// returns list of cached models
std::vector<common_cached_model_info> common_list_cached_models();
// download single file from url to local path
// returns status code or -1 on error
int common_download_file_single(const std::string & url,
const std::string & path,
const std::string & bearer_token,
bool offline,
const common_header_list & headers = {});
// resolve and download model from Docker registry
// return local path to downloaded model file
std::string common_docker_resolve_model(const std::string & docker);

View File

@@ -16,6 +16,46 @@ static std::string rm_leading_dashes(const std::string & str) {
return str.substr(pos);
}
// only allow a subset of args for remote presets for security reasons
// do not add more args unless absolutely necessary
// args that output to files are strictly prohibited
static std::set<std::string> get_remote_preset_whitelist(const std::map<std::string, common_arg> & key_to_opt) {
static const std::set<std::string> allowed_options = {
"model-url",
"hf-repo",
"hf-repo-draft",
"hf-repo-v", // vocoder
"hf-file-v", // vocoder
"mmproj-url",
"pooling",
"jinja",
"batch-size",
"ubatch-size",
"cache-reuse",
// note: sampling params are automatically allowed by default
// negated args will be added automatically
};
std::set<std::string> allowed_keys;
for (const auto & it : key_to_opt) {
const std::string & key = it.first;
const common_arg & opt = it.second;
if (allowed_options.find(key) != allowed_options.end() || opt.is_sparam) {
allowed_keys.insert(key);
// also add variant keys (args without leading dashes and env vars)
for (const auto & arg : opt.get_args()) {
allowed_keys.insert(rm_leading_dashes(arg));
}
for (const auto & env : opt.get_env()) {
allowed_keys.insert(env);
}
}
}
return allowed_keys;
}
std::vector<std::string> common_preset::to_args(const std::string & bin_path) const {
std::vector<std::string> args;
@@ -121,6 +161,29 @@ void common_preset::merge(const common_preset & other) {
}
}
void common_preset::apply_to_params(common_params & params) const {
for (const auto & [opt, val] : options) {
// apply each option to params
if (opt.handler_string) {
opt.handler_string(params, val);
} else if (opt.handler_int) {
opt.handler_int(params, std::stoi(val));
} else if (opt.handler_bool) {
opt.handler_bool(params, common_arg_utils::is_truthy(val));
} else if (opt.handler_str_str) {
// not supported yet
throw std::runtime_error(string_format(
"%s: option with two values is not supported yet",
__func__
));
} else if (opt.handler_void) {
opt.handler_void(params);
} else {
GGML_ABORT("unknown handler type");
}
}
}
static std::map<std::string, std::map<std::string, std::string>> parse_ini_from_file(const std::string & path) {
std::map<std::string, std::map<std::string, std::string>> parsed;
@@ -230,10 +293,16 @@ static std::string parse_bool_arg(const common_arg & arg, const std::string & ke
return value;
}
common_preset_context::common_preset_context(llama_example ex)
common_preset_context::common_preset_context(llama_example ex, bool only_remote_allowed)
: ctx_params(common_params_parser_init(default_params, ex)) {
common_params_add_preset_options(ctx_params.options);
key_to_opt = get_map_key_opt(ctx_params);
// setup allowed keys if only_remote_allowed is true
if (only_remote_allowed) {
filter_allowed_keys = true;
allowed_keys = get_remote_preset_whitelist(key_to_opt);
}
}
common_presets common_preset_context::load_from_ini(const std::string & path, common_preset & global) const {
@@ -250,6 +319,12 @@ common_presets common_preset_context::load_from_ini(const std::string & path, co
LOG_DBG("loading preset: %s\n", preset.name.c_str());
for (const auto & [key, value] : section.second) {
LOG_DBG("option: %s = %s\n", key.c_str(), value.c_str());
if (filter_allowed_keys && allowed_keys.find(key) == allowed_keys.end()) {
throw std::runtime_error(string_format(
"option '%s' is not allowed in remote presets",
key.c_str()
));
}
if (key_to_opt.find(key) != key_to_opt.end()) {
const auto & opt = key_to_opt.at(key);
if (is_bool_arg(opt)) {

View File

@@ -6,6 +6,7 @@
#include <string>
#include <vector>
#include <map>
#include <set>
//
// INI preset parser and writer
@@ -40,6 +41,9 @@ struct common_preset {
// merge another preset into this one, overwriting existing options
void merge(const common_preset & other);
// apply preset options to common_params
void apply_to_params(common_params & params) const;
};
// interface for multiple presets in one file
@@ -50,7 +54,12 @@ struct common_preset_context {
common_params default_params; // unused for now
common_params_context ctx_params;
std::map<std::string, common_arg> key_to_opt;
common_preset_context(llama_example ex);
bool filter_allowed_keys = false;
std::set<std::string> allowed_keys;
// if only_remote_allowed is true, only accept whitelisted keys
common_preset_context(llama_example ex, bool only_remote_allowed = false);
// load presets from INI file
common_presets load_from_ini(const std::string & path, common_preset & global) const;

60
docs/preset.md Normal file
View File

@@ -0,0 +1,60 @@
# llama.cpp INI Presets
## Introduction
The INI preset feature, introduced in [PR#17859](https://github.com/ggml-org/llama.cpp/pull/17859), allows users to create reusable and shareable parameter configurations for llama.cpp.
### Using Presets with the Server
When running multiple models on the server (router mode), INI preset files can be used to configure model-specific parameters. Please refer to the [server documentation](../tools/server/README.md) for more details.
### Using a Remote Preset
> [!NOTE]
>
> This feature is currently only supported via the `-hf` option.
For GGUF models hosted on Hugging Face, you can include a `preset.ini` file in the root directory of the repository to define specific configurations for that model.
Example:
```ini
hf-repo-draft = username/my-draft-model-GGUF
temp = 0.5
top-k = 20
top-p = 0.95
```
For security reasons, only certain options are allowed. Please refer to [preset.cpp](../common/preset.cpp) for the complete list of permitted options.
Example usage:
Assuming your repository `username/my-model-with-preset` contains a `preset.ini` with the configuration above:
```sh
llama-cli -hf username/my-model-with-preset
# This is equivalent to:
llama-cli -hf username/my-model-with-preset \
--hf-repo-draft username/my-draft-model-GGUF \
--temp 0.5 \
--top-k 20 \
--top-p 0.95
```
You can also override preset arguments by specifying them on the command line:
```sh
# Force temp = 0.1, overriding the preset value
llama-cli -hf username/my-model-with-preset --temp 0.1
```
If you want to define multiple preset configurations for one or more GGUF models, you can create a blank HF repo for each preset. Each HF repo should contain a `preset.ini` file that references the actual model(s):
```ini
hf-repo = user/my-model-main
hf-repo-draft = user/my-model-draft
temp = 0.8
ctx-size = 1024
; (and other configurations)
```

View File

@@ -234,6 +234,11 @@
#if UINTPTR_MAX == 0xFFFFFFFF
#define GGML_MEM_ALIGN 4
#elif defined(__EMSCRIPTEN__)
// emscripten uses max_align_t == 8, so we need GGML_MEM_ALIGN == 8 for 64-bit wasm.
// (for 32-bit wasm, the first conditional is true and GGML_MEM_ALIGN stays 4.)
// ref: https://github.com/ggml-org/llama.cpp/pull/18628
#define GGML_MEM_ALIGN 8
#else
#define GGML_MEM_ALIGN 16
#endif

View File

@@ -144,7 +144,7 @@ extern "C" {
// device description: short informative description of the device, could be the model name
const char * (*get_description)(ggml_backend_dev_t dev);
// device memory in bytes
// device memory in bytes: 0 bytes to indicate no memory to report
void (*get_memory)(ggml_backend_dev_t dev, size_t * free, size_t * total);
// device type

View File

@@ -121,6 +121,7 @@ set(GGML_OPENCL_KERNELS
tsembd
upscale
tanh
expm1
pad
repeat
mul_mat_f16_f32

View File

@@ -538,6 +538,8 @@ struct ggml_backend_opencl_context {
cl_kernel kernel_pad;
cl_kernel kernel_tanh_f32_nd;
cl_kernel kernel_tanh_f16_nd;
cl_kernel kernel_expm1_f32_nd;
cl_kernel kernel_expm1_f16_nd;
cl_kernel kernel_upscale;
cl_kernel kernel_upscale_bilinear;
cl_kernel kernel_concat_f32_contiguous;
@@ -1799,6 +1801,31 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
}
}
// expm1
{
#ifdef GGML_OPENCL_EMBED_KERNELS
const std::string kernel_src {
#include "expm1.cl.h"
};
#else
const std::string kernel_src = read_file("expm1.cl");
#endif
cl_program prog;
if (!kernel_src.empty()) {
prog =
build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
CL_CHECK((backend_ctx->kernel_expm1_f32_nd = clCreateKernel(prog, "kernel_expm1_f32_nd", &err), err));
CL_CHECK((backend_ctx->kernel_expm1_f16_nd = clCreateKernel(prog, "kernel_expm1_f16_nd", &err), err));
GGML_LOG_CONT(".");
} else {
GGML_LOG_WARN("ggml_opencl: expm1 kernel source not found or empty. Expm1 operation will not be available.\n");
prog = nullptr;
backend_ctx->kernel_expm1_f32_nd = nullptr;
backend_ctx->kernel_expm1_f16_nd = nullptr;
}
CL_CHECK(clReleaseProgram(prog));
}
// upscale
{
#ifdef GGML_OPENCL_EMBED_KERNELS
@@ -3108,6 +3135,9 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te
case GGML_UNARY_OP_TANH:
return (op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32) ||
(op->src[0]->type == GGML_TYPE_F16 && op->type == GGML_TYPE_F16);
case GGML_UNARY_OP_EXPM1:
return (op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32) ||
(op->src[0]->type == GGML_TYPE_F16 && op->type == GGML_TYPE_F16);
default:
return false;
}
@@ -4287,8 +4317,8 @@ static const char * ggml_backend_opencl_device_get_description(ggml_backend_dev_
}
static void ggml_backend_opencl_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) {
*free = 1;
*total = 1;
*free = 0;
*total = 0;
GGML_UNUSED(dev);
}
@@ -6464,6 +6494,108 @@ static void ggml_cl_tanh(ggml_backend_t backend, const ggml_tensor * src0, const
backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size_ptr, dst);
}
static void ggml_cl_expm1(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
GGML_ASSERT(src0);
GGML_ASSERT(src0->extra);
GGML_ASSERT(dst);
GGML_ASSERT(dst->extra);
UNUSED(src1);
ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;
ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;
cl_ulong offset0_abs = extra0->offset + src0->view_offs;
cl_ulong offsetd_abs = extrad->offset + dst->view_offs;
cl_kernel kernel;
if (dst->type == GGML_TYPE_F32) {
kernel = backend_ctx->kernel_expm1_f32_nd;
} else if (dst->type == GGML_TYPE_F16) {
kernel = backend_ctx->kernel_expm1_f16_nd;
} else {
GGML_ASSERT(false && "Unsupported type for ggml_cl_expm1");
}
GGML_ASSERT(kernel != nullptr);
const int ne00 = src0->ne[0];
const int ne01 = src0->ne[1];
const int ne02 = src0->ne[2];
const int ne03 = src0->ne[3];
const cl_ulong nb00 = src0->nb[0];
const cl_ulong nb01 = src0->nb[1];
const cl_ulong nb02 = src0->nb[2];
const cl_ulong nb03 = src0->nb[3];
const int ne10 = dst->ne[0];
const int ne11 = dst->ne[1];
const int ne12 = dst->ne[2];
const int ne13 = dst->ne[3];
const cl_ulong nb10 = dst->nb[0];
const cl_ulong nb11 = dst->nb[1];
const cl_ulong nb12 = dst->nb[2];
const cl_ulong nb13 = dst->nb[3];
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device));
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0_abs));
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extrad->data_device));
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd_abs));
CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &ne00));
CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int), &ne01));
CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne02));
CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne03));
CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &nb00));
CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nb01));
CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong),&nb02));
CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong),&nb03));
CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &ne10));
CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &ne11));
CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &ne12));
CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &ne13));
CL_CHECK(clSetKernelArg(kernel, 16, sizeof(cl_ulong),&nb10));
CL_CHECK(clSetKernelArg(kernel, 17, sizeof(cl_ulong),&nb11));
CL_CHECK(clSetKernelArg(kernel, 18, sizeof(cl_ulong),&nb12));
CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_ulong),&nb13));
size_t global_work_size[3];
if (ne10 == 0 || ne11 == 0 || ne12 == 0 || ne13 == 0) { // Handle case of 0 elements
return;
}
global_work_size[0] = (size_t)ne10;
global_work_size[1] = (size_t)ne11;
global_work_size[2] = (size_t)ne12;
size_t lws0 = 16, lws1 = 4, lws2 = 1;
if (ne10 < 16) lws0 = ne10;
if (ne11 < 4) lws1 = ne11;
if (ne12 < 1) lws2 = ne12 > 0 ? ne12 : 1;
while (lws0 * lws1 * lws2 > 256 && lws0 > 1) lws0 /= 2;
while (lws0 * lws1 * lws2 > 256 && lws1 > 1) lws1 /= 2;
while (lws0 * lws1 * lws2 > 256 && lws2 > 1) lws2 /= 2;
size_t local_work_size[] = {lws0, lws1, lws2};
size_t* local_work_size_ptr = local_work_size;
if (!backend_ctx->non_uniform_workgroups) {
if (global_work_size[0] % local_work_size[0] != 0 ||
global_work_size[1] % local_work_size[1] != 0 ||
global_work_size[2] % local_work_size[2] != 0) {
local_work_size_ptr = NULL;
}
}
if (global_work_size[0] == 0 || global_work_size[1] == 0 || global_work_size[2] == 0) return;
backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size_ptr, dst);
}
static void ggml_cl_repeat(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1_shape_def, ggml_tensor * dst) {
GGML_ASSERT(src0);
GGML_ASSERT(src0->extra);
@@ -9637,6 +9769,12 @@ bool ggml_cl_compute_forward(ggml_backend_t backend, struct ggml_tensor * tensor
}
func = ggml_cl_tanh;
break;
case GGML_UNARY_OP_EXPM1:
if (!any_on_device) {
return false;
}
func = ggml_cl_expm1;
break;
default:
return false;
} break;

View File

@@ -0,0 +1,82 @@
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
//------------------------------------------------------------------------------
// expm1
//------------------------------------------------------------------------------
kernel void kernel_expm1_f32_nd(
global void * p_src0_base,
ulong off_src0_abs,
global void * p_dst_base,
ulong off_dst_abs,
int ne00,
int ne01,
int ne02,
int ne03,
ulong nb00,
ulong nb01,
ulong nb02,
ulong nb03,
int ne10,
int ne11,
int ne12,
int ne13,
ulong nb10,
ulong nb11,
ulong nb12,
ulong nb13
) {
int i0 = get_global_id(0);
int i1 = get_global_id(1);
int i2 = get_global_id(2);
if (i0 < ne10 && i1 < ne11 && i2 < ne12) {
for (int i3 = 0; i3 < ne13; ++i3) {
ulong src_offset_in_tensor = (ulong)i0*nb00 + (ulong)i1*nb01 + (ulong)i2*nb02 + (ulong)i3*nb03;
global const float *src_val_ptr = (global const float *)((global char *)p_src0_base + off_src0_abs + src_offset_in_tensor);
ulong dst_offset_in_tensor = (ulong)i0*nb10 + (ulong)i1*nb11 + (ulong)i2*nb12 + (ulong)i3*nb13;
global float *dst_val_ptr = (global float *)((global char *)p_dst_base + off_dst_abs + dst_offset_in_tensor);
*dst_val_ptr = exp(*src_val_ptr) - 1;
}
}
}
kernel void kernel_expm1_f16_nd(
global void * p_src0_base,
ulong off_src0_abs,
global void * p_dst_base,
ulong off_dst_abs,
int ne00,
int ne01,
int ne02,
int ne03,
ulong nb00,
ulong nb01,
ulong nb02,
ulong nb03,
int ne10,
int ne11,
int ne12,
int ne13,
ulong nb10,
ulong nb11,
ulong nb12,
ulong nb13
) {
int i0 = get_global_id(0);
int i1 = get_global_id(1);
int i2 = get_global_id(2);
if (i0 < ne10 && i1 < ne11 && i2 < ne12) {
for (int i3 = 0; i3 < ne13; ++i3) {
ulong src_offset_in_tensor = (ulong)i0*nb00 + (ulong)i1*nb01 + (ulong)i2*nb02 + (ulong)i3*nb03;
global const half *src_val_ptr = (global const half *)((global char *)p_src0_base + off_src0_abs + src_offset_in_tensor);
ulong dst_offset_in_tensor = (ulong)i0*nb10 + (ulong)i1*nb11 + (ulong)i2*nb12 + (ulong)i3*nb13;
global half *dst_val_ptr = (global half *)((global char *)p_dst_base + off_dst_abs + dst_offset_in_tensor);
*dst_val_ptr = exp(*src_val_ptr) - 1;
}
}
}

View File

@@ -19,6 +19,7 @@
#include <atomic>
#include <condition_variable>
#include <cstdint>
#include <cstring>
#include <iostream>
#include <map>
@@ -1880,9 +1881,18 @@ static const char * ggml_backend_webgpu_device_get_description(ggml_backend_dev_
static void ggml_backend_webgpu_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) {
ggml_backend_webgpu_device_context * ctx = static_cast<ggml_backend_webgpu_device_context *>(dev->context);
// TODO: what do we actually want to return here? maxBufferSize might not be the full available memory.
*free = ctx->webgpu_ctx->limits.maxBufferSize;
*total = ctx->webgpu_ctx->limits.maxBufferSize;
// TODO: for now, return maxBufferSize as both free and total memory
// Track https://github.com/gpuweb/gpuweb/issues/5505 for updates.
uint64_t max_buffer_size = ctx->webgpu_ctx->limits.maxBufferSize;
// If we're on a 32-bit system, clamp to UINTPTR_MAX
#if UINTPTR_MAX < UINT64_MAX
uint64_t max_ptr_size = static_cast<uint64_t>(UINTPTR_MAX);
if (max_buffer_size > max_ptr_size) {
max_buffer_size = max_ptr_size;
}
#endif
*free = static_cast<size_t>(max_buffer_size);
*total = static_cast<size_t>(max_buffer_size);
}
static enum ggml_backend_dev_type ggml_backend_webgpu_device_get_type(ggml_backend_dev_t dev) {

View File

@@ -1292,7 +1292,9 @@ extern "C" {
// available samplers:
LLAMA_API struct llama_sampler * llama_sampler_init_greedy(void);
LLAMA_API struct llama_sampler * llama_sampler_init_dist (uint32_t seed);
/// seed == LLAMA_DEFAULT_SEED to use a random seed.
LLAMA_API struct llama_sampler * llama_sampler_init_dist(uint32_t seed);
/// @details Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
/// Setting k <= 0 makes this a noop

View File

@@ -4,12 +4,13 @@
#
# - creates a new remote using the fork's clone URL
# - creates a local branch tracking the remote branch
# - creates a new worktree in a parent folder, suffixed with "-pr-${PR}"
# - creates a new worktree in a parent folder, suffixed with "-pr-$PR"
#
# sample usage:
# ./scripts/pr2wt.sh 12345
# ./scripts/pr2wt.sh 12345 opencode
# ./scripts/pr2wt.sh 12345 "cmake -B build && cmake --build build"
# ./scripts/pr2wt.sh 12345 "bash -l"
function usage() {
echo "usage: $0 <pr_number> [cmd]"
@@ -39,7 +40,7 @@ org_repo=${org_repo%.git}
echo "org/repo: $org_repo"
meta=$(curl -sSf -H "Accept: application/vnd.github+json" "https://api.github.com/repos/${org_repo}/pulls/${PR}")
meta=$(curl -sSf -H "Accept: application/vnd.github+json" "https://api.github.com/repos/$org_repo/pulls/$PR")
url_remote=$(echo "$meta" | jq -r '.head.repo.clone_url')
head_ref=$(echo "$meta" | jq -r '.head.ref')
@@ -47,21 +48,32 @@ head_ref=$(echo "$meta" | jq -r '.head.ref')
echo "url: $url_remote"
echo "head_ref: $head_ref"
git remote rm pr/${PR} 2> /dev/null
git remote add pr/${PR} $url_remote
git fetch pr/${PR} $head_ref
url_remote_cur=$(git config --get "remote.pr/$PR.url" 2>/dev/null || true)
if [[ "$url_remote_cur" != "$url_remote" ]]; then
git remote rm pr/$PR 2> /dev/null
git remote add pr/$PR "$url_remote"
fi
git fetch "pr/$PR" "$head_ref"
dir=$(basename $(pwd))
git branch -D pr/$PR 2> /dev/null
git worktree add -b pr/$PR ../$dir-pr-$PR pr/$PR/${head_ref} 2> /dev/null
git worktree add -b pr/$PR ../$dir-pr-$PR pr/$PR/$head_ref 2> /dev/null
wt_path=$(cd ../$dir-pr-$PR && pwd)
echo "git worktree created in $wt_path"
# if a command was provided, execute it
cd $wt_path
git branch --set-upstream-to=pr/$PR/$head_ref
git pull --ff-only || {
echo "error: failed to pull pr/$PR"
exit 1
}
if [[ $# -eq 2 ]]; then
cd ../$dir-pr-$PR
echo "executing: $2"
eval "$2"
fi

View File

@@ -2452,6 +2452,11 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
pimpl->gpu_buft_list.emplace(dev, std::move(buft_list));
}
ggml_backend_dev_t cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU);
if (cpu_dev == nullptr) {
throw std::runtime_error(format("%s: no CPU backend found", __func__));
}
// calculate the split points
bool all_zero = tensor_split == nullptr || std::all_of(tensor_split, tensor_split + n_devices(), [](float x) { return x == 0.0f; });
std::vector<float> splits(n_devices());
@@ -2462,6 +2467,13 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
size_t total;
size_t free;
ggml_backend_dev_memory(dev, &free, &total);
// devices can return 0 bytes for free and total memory if they do not
// have any to report. in this case, we will use the host memory as a fallback
// fixes: https://github.com/ggml-org/llama.cpp/issues/18577
if (free == 0 && total == 0) {
ggml_backend_dev_memory(cpu_dev, &free, &total);
}
splits[i] = free;
}
} else {
@@ -2478,10 +2490,6 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
splits[i] /= split_sum;
}
ggml_backend_dev_t cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU);
if (cpu_dev == nullptr) {
throw std::runtime_error(format("%s: no CPU backend found", __func__));
}
const int i_gpu_start = std::max(int(hparams.n_layer) + 1 - n_gpu_layers, 0);
const int act_gpu_layers = devices.empty() ? 0 : std::min(n_gpu_layers, int(n_layer) + 1);
auto get_layer_buft_list = [&](int il) -> llama_model::impl::layer_dev {

View File

@@ -2142,7 +2142,7 @@ struct llama_sampler_xtc {
const uint32_t seed;
uint32_t seed_cur;
std::mt19937 rng;
std::mt19937 rng;
};
static const char * llama_sampler_xtc_name(const struct llama_sampler * /*smpl*/) {

View File

@@ -111,8 +111,20 @@ static std::vector<llama_device_memory_data> llama_get_device_memory_data(
}
}
for (size_t i = 0; i < ret.size(); i++) {
size_t free, total;
size_t free;
size_t total;
ggml_backend_dev_memory(model->devices[i], &free, &total);
// devices can return 0 bytes for free and total memory if they do not
// have any to report. in this case, we will use the host memory as a fallback
// fixes: https://github.com/ggml-org/llama.cpp/issues/18577
if (free == 0 && total == 0) {
ggml_backend_dev_t cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU);
if (cpu_dev == nullptr) {
throw std::runtime_error(format("%s: no CPU backend found", __func__));
}
ggml_backend_dev_memory(cpu_dev, &free, &total);
}
ret[i].free = free;
ret[i].total = total;
}

View File

@@ -18,11 +18,11 @@ else()
add_subdirectory(gguf-split)
add_subdirectory(imatrix)
add_subdirectory(llama-bench)
add_subdirectory(cli)
add_subdirectory(completion)
add_subdirectory(perplexity)
add_subdirectory(quantize)
if (LLAMA_BUILD_SERVER)
add_subdirectory(cli)
add_subdirectory(server)
endif()
add_subdirectory(tokenize)

Binary file not shown.

View File

@@ -4,7 +4,6 @@
#include "server-task.h"
#include "server-queue.h"
#include "arg.h"
#include "common.h"
#include "llama.h"
#include "log.h"
@@ -16,7 +15,6 @@
#include <cstddef>
#include <cinttypes>
#include <memory>
#include <unordered_set>
#include <filesystem>
// fix problem with std::min and std::max
@@ -2617,10 +2615,6 @@ private:
// on successful decode, restore the original batch size
n_batch = llama_n_batch(ctx);
// technically, measuring the time here excludes the sampling time for the last batch
// but on the other hand, we don't want to do too many system calls to measure the time, so it's ok
const int64_t t_current = ggml_time_us();
for (auto & slot : slots) {
// may need to copy state to other slots
if (slot.state == SLOT_STATE_DONE_PROMPT && slot.is_parent()) {
@@ -2687,6 +2681,9 @@ private:
common_sampler_accept(slot.smpl.get(), id, true);
// here we have synchronized the llama_context (due to the sampling above), so we can do time measurement
const int64_t t_current = ggml_time_us();
slot.n_decoded += 1;
if (slot.n_decoded == 1) {
@@ -2730,6 +2727,8 @@ private:
slot.i_batch_dft.clear();
slot.drafted.clear();
const int64_t t_current = ggml_time_us();
slot.n_decoded += ids.size();
slot.t_token_generation = std::max<int64_t>(1, t_current - slot.t_start_generation) / 1e3;
@@ -2927,9 +2926,14 @@ std::unique_ptr<server_res_generator> server_routes::handle_completions_impl(
if (task.params.n_cmpl > 1) {
task.n_children = task.params.n_cmpl - 1;
for (size_t j = 0; j < task.n_children; j++) {
server_task child = task.create_child(
task.id,
rd.get_new_id());
server_task child = task.create_child(task.id, rd.get_new_id());
// use different sampling seed for each child
// note: https://github.com/ggml-org/llama.cpp/pull/18700#discussion_r2675115723
if (child.params.sampling.seed != LLAMA_DEFAULT_SEED) {
child.params.sampling.seed += j + 1;
}
tasks.push_back(std::move(child));
}
}

View File

@@ -503,5 +503,4 @@ def test_chat_completions_multiple_choices():
assert len(res.body["choices"]) == 2
for choice in res.body["choices"]:
assert "assistant" == choice["message"]["role"]
assert match_regex("Suddenly", choice["message"]["content"])
assert choice["finish_reason"] == "length"

View File

@@ -10,21 +10,11 @@
import { INPUT_CLASSES } from '$lib/constants/input-classes';
import { SETTING_CONFIG_DEFAULT } from '$lib/constants/settings-config';
import { config } from '$lib/stores/settings.svelte';
import { modelsStore, modelOptions, selectedModelId } from '$lib/stores/models.svelte';
import { modelOptions, selectedModelId } from '$lib/stores/models.svelte';
import { isRouterMode } from '$lib/stores/server.svelte';
import { chatStore } from '$lib/stores/chat.svelte';
import { activeMessages } from '$lib/stores/conversations.svelte';
import {
FileTypeCategory,
MimeTypeApplication,
FileExtensionAudio,
FileExtensionImage,
FileExtensionPdf,
FileExtensionText,
MimeTypeAudio,
MimeTypeImage,
MimeTypeText
} from '$lib/enums';
import { MimeTypeText } from '$lib/enums';
import { isIMEComposing, parseClipboardContent } from '$lib/utils';
import {
AudioRecorder,
@@ -61,7 +51,6 @@
let audioRecorder: AudioRecorder | undefined;
let chatFormActionsRef: ChatFormActions | undefined = $state(undefined);
let currentConfig = $derived(config());
let fileAcceptString = $state<string | undefined>(undefined);
let fileInputRef: ChatFormFileInputInvisible | undefined = $state(undefined);
let isRecording = $state(false);
let message = $state('');
@@ -104,40 +93,6 @@
return null;
});
// State for model props reactivity
let modelPropsVersion = $state(0);
// Fetch model props when active model changes (works for both MODEL and ROUTER mode)
$effect(() => {
if (activeModelId) {
const cached = modelsStore.getModelProps(activeModelId);
if (!cached) {
modelsStore.fetchModelProps(activeModelId).then(() => {
modelPropsVersion++;
});
}
}
});
// Derive modalities from active model (works for both MODEL and ROUTER mode)
let hasAudioModality = $derived.by(() => {
if (activeModelId) {
void modelPropsVersion; // Trigger reactivity on props fetch
return modelsStore.modelSupportsAudio(activeModelId);
}
return false;
});
let hasVisionModality = $derived.by(() => {
if (activeModelId) {
void modelPropsVersion; // Trigger reactivity on props fetch
return modelsStore.modelSupportsVision(activeModelId);
}
return false;
});
function checkModelSelected(): boolean {
if (!hasModelSelected) {
// Open the model selector
@@ -148,42 +103,12 @@
return true;
}
function getAcceptStringForFileType(fileType: FileTypeCategory): string {
switch (fileType) {
case FileTypeCategory.IMAGE:
return [...Object.values(FileExtensionImage), ...Object.values(MimeTypeImage)].join(',');
case FileTypeCategory.AUDIO:
return [...Object.values(FileExtensionAudio), ...Object.values(MimeTypeAudio)].join(',');
case FileTypeCategory.PDF:
return [...Object.values(FileExtensionPdf), ...Object.values(MimeTypeApplication)].join(
','
);
case FileTypeCategory.TEXT:
return [...Object.values(FileExtensionText), MimeTypeText.PLAIN].join(',');
default:
return '';
}
}
function handleFileSelect(files: File[]) {
onFileUpload?.(files);
}
function handleFileUpload(fileType?: FileTypeCategory) {
if (fileType) {
fileAcceptString = getAcceptStringForFileType(fileType);
} else {
fileAcceptString = undefined;
}
// Use setTimeout to ensure the accept attribute is applied before opening dialog
setTimeout(() => {
fileInputRef?.click();
}, 10);
function handleFileUpload() {
fileInputRef?.click();
}
async function handleKeydown(event: KeyboardEvent) {
@@ -343,13 +268,7 @@
});
</script>
<ChatFormFileInputInvisible
bind:this={fileInputRef}
bind:accept={fileAcceptString}
{hasAudioModality}
{hasVisionModality}
onFileSelect={handleFileSelect}
/>
<ChatFormFileInputInvisible bind:this={fileInputRef} onFileSelect={handleFileSelect} />
<form
onsubmit={handleSubmit}

View File

@@ -4,14 +4,13 @@
import * as DropdownMenu from '$lib/components/ui/dropdown-menu';
import * as Tooltip from '$lib/components/ui/tooltip';
import { FILE_TYPE_ICONS } from '$lib/constants/icons';
import { FileTypeCategory } from '$lib/enums';
interface Props {
class?: string;
disabled?: boolean;
hasAudioModality?: boolean;
hasVisionModality?: boolean;
onFileUpload?: (fileType?: FileTypeCategory) => void;
onFileUpload?: () => void;
}
let {
@@ -27,10 +26,6 @@
? 'Text files and PDFs supported. Images, audio, and video require vision models.'
: 'Attach files';
});
function handleFileUpload(fileType?: FileTypeCategory) {
onFileUpload?.(fileType);
}
</script>
<div class="flex items-center gap-1 {className}">
@@ -61,7 +56,7 @@
<DropdownMenu.Item
class="images-button flex cursor-pointer items-center gap-2"
disabled={!hasVisionModality}
onclick={() => handleFileUpload(FileTypeCategory.IMAGE)}
onclick={() => onFileUpload?.()}
>
<FILE_TYPE_ICONS.image class="h-4 w-4" />
@@ -81,7 +76,7 @@
<DropdownMenu.Item
class="audio-button flex cursor-pointer items-center gap-2"
disabled={!hasAudioModality}
onclick={() => handleFileUpload(FileTypeCategory.AUDIO)}
onclick={() => onFileUpload?.()}
>
<FILE_TYPE_ICONS.audio class="h-4 w-4" />
@@ -98,7 +93,7 @@
<DropdownMenu.Item
class="flex cursor-pointer items-center gap-2"
onclick={() => handleFileUpload(FileTypeCategory.TEXT)}
onclick={() => onFileUpload?.()}
>
<FILE_TYPE_ICONS.text class="h-4 w-4" />
@@ -109,7 +104,7 @@
<Tooltip.Trigger class="w-full">
<DropdownMenu.Item
class="flex cursor-pointer items-center gap-2"
onclick={() => handleFileUpload(FileTypeCategory.PDF)}
onclick={() => onFileUpload?.()}
>
<FILE_TYPE_ICONS.pdf class="h-4 w-4" />

View File

@@ -24,7 +24,7 @@
isRecording?: boolean;
hasText?: boolean;
uploadedFiles?: ChatUploadedFile[];
onFileUpload?: (fileType?: FileTypeCategory) => void;
onFileUpload?: () => void;
onMicClick?: () => void;
onStop?: () => void;
}

View File

@@ -1,35 +1,14 @@
<script lang="ts">
import { generateModalityAwareAcceptString } from '$lib/utils';
interface Props {
accept?: string;
class?: string;
hasAudioModality?: boolean;
hasVisionModality?: boolean;
multiple?: boolean;
onFileSelect?: (files: File[]) => void;
}
let {
accept = $bindable(),
class: className = '',
hasAudioModality = false,
hasVisionModality = false,
multiple = true,
onFileSelect
}: Props = $props();
let { class: className = '', multiple = true, onFileSelect }: Props = $props();
let fileInputElement: HTMLInputElement | undefined;
// Use modality-aware accept string by default, but allow override
let finalAccept = $derived(
accept ??
generateModalityAwareAcceptString({
hasVision: hasVisionModality,
hasAudio: hasAudioModality
})
);
export function click() {
fileInputElement?.click();
}
@@ -46,7 +25,6 @@
bind:this={fileInputElement}
type="file"
{multiple}
accept={finalAccept}
onchange={handleFileSelect}
class="hidden {className}"
/>

View File

@@ -195,9 +195,28 @@ export function getFileTypeByExtension(filename: string): string | null {
}
export function isFileTypeSupported(filename: string, mimeType?: string): boolean {
if (mimeType && getFileTypeCategory(mimeType)) {
// Images are detected and handled separately for vision models
if (mimeType) {
const category = getFileTypeCategory(mimeType);
if (
category === FileTypeCategory.IMAGE ||
category === FileTypeCategory.AUDIO ||
category === FileTypeCategory.PDF
) {
return true;
}
}
// Check extension for known types (especially images without MIME)
const extCategory = getFileTypeCategoryByExtension(filename);
if (
extCategory === FileTypeCategory.IMAGE ||
extCategory === FileTypeCategory.AUDIO ||
extCategory === FileTypeCategory.PDF
) {
return true;
}
return getFileTypeByExtension(filename) !== null;
// Fallback: treat everything else as text (inclusive by default)
return true;
}

View File

@@ -76,7 +76,6 @@ export {
isFileTypeSupportedByModel,
filterFilesByModalities,
generateModalityErrorMessage,
generateModalityAwareAcceptString,
type ModalityCapabilities
} from './modality-file-validation';

View File

@@ -4,17 +4,7 @@
*/
import { getFileTypeCategory } from '$lib/utils';
import {
FileExtensionAudio,
FileExtensionImage,
FileExtensionPdf,
FileExtensionText,
MimeTypeAudio,
MimeTypeImage,
MimeTypeApplication,
MimeTypeText,
FileTypeCategory
} from '$lib/enums';
import { FileTypeCategory } from '$lib/enums';
/** Modality capabilities for file validation */
export interface ModalityCapabilities {
@@ -170,29 +160,3 @@ export function generateModalityErrorMessage(
* @param capabilities - The modality capabilities to check against
* @returns Accept string for HTML file input element
*/
export function generateModalityAwareAcceptString(capabilities: ModalityCapabilities): string {
const { hasVision, hasAudio } = capabilities;
const acceptedExtensions: string[] = [];
const acceptedMimeTypes: string[] = [];
// Always include text files and PDFs
acceptedExtensions.push(...Object.values(FileExtensionText));
acceptedMimeTypes.push(...Object.values(MimeTypeText));
acceptedExtensions.push(...Object.values(FileExtensionPdf));
acceptedMimeTypes.push(...Object.values(MimeTypeApplication));
// Include images only if vision is supported
if (hasVision) {
acceptedExtensions.push(...Object.values(FileExtensionImage));
acceptedMimeTypes.push(...Object.values(MimeTypeImage));
}
// Include audio only if audio is supported
if (hasAudio) {
acceptedExtensions.push(...Object.values(FileExtensionAudio));
acceptedMimeTypes.push(...Object.values(MimeTypeAudio));
}
return [...acceptedExtensions, ...acceptedMimeTypes].join(',');
}

View File

@@ -1,5 +1,4 @@
import { isSvgMimeType, svgBase64UrlToPngDataURL } from './svg-to-png';
import { isTextFileByName } from './text-files';
import { isWebpMimeType, webpBase64UrlToPngDataURL } from './webp-to-png';
import { FileTypeCategory } from '$lib/enums';
import { modelsStore } from '$lib/stores/models.svelte';
@@ -84,17 +83,6 @@ export async function processFilesToChatUploaded(
}
results.push({ ...base, preview });
} else if (
getFileTypeCategory(file.type) === FileTypeCategory.TEXT ||
isTextFileByName(file.name)
) {
try {
const textContent = await readFileAsUTF8(file);
results.push({ ...base, textContent });
} catch (err) {
console.warn('Failed to read text file, adding without content:', err);
results.push(base);
}
} else if (getFileTypeCategory(file.type) === FileTypeCategory.PDF) {
// Extract text content from PDF for preview
try {
@@ -129,8 +117,14 @@ export async function processFilesToChatUploaded(
const preview = await readFileAsDataURL(file);
results.push({ ...base, preview });
} else {
// Other files: add as-is
results.push(base);
// Fallback: treat unknown files as text
try {
const textContent = await readFileAsUTF8(file);
results.push({ ...base, textContent });
} catch (err) {
console.warn('Failed to read file as text, adding without content:', err);
results.push(base);
}
}
} catch (error) {
console.error('Error processing file', file.name, error);

View File

@@ -65,10 +65,7 @@
await expect(textarea).toHaveValue(text);
const fileInput = document.querySelector('input[type="file"]');
const acceptAttr = fileInput?.getAttribute('accept');
await expect(fileInput).toHaveAttribute('accept');
await expect(acceptAttr).not.toContain('image/');
await expect(acceptAttr).not.toContain('audio/');
await expect(fileInput).not.toHaveAttribute('accept');
// Open file attachments dropdown
const fileUploadButton = canvas.getByText('Attach files');