From b1623a614c682bad576ab7dc19cf613b2af94e6d Mon Sep 17 00:00:00 2001
From: Ruben Ortlam
Date: Mon, 20 Apr 2026 14:48:55 +0200
Subject: [PATCH] handle models that need to be downloaded before estimation

---
Review notes (this section is ignored by git am):
 - download_model(): the result of subprocess_join() is now checked; before,
   a failed join left exit_code at 0 and the download was reported as
   successful.
 - load(): the whole body of the detached download thread is wrapped in
   try/catch, not only the _load() call, because an exception escaping a
   std::thread entry function ends in std::terminate.
 - get_model_memory_per_device(): "if(" -> "if (" to match the surrounding
   style.
 - NOTE(review): the detached thread captures `this`; presumably the
   server_models instance outlives it - confirm, or keep the thread joinable.
 - NOTE(review): indentation, blank context lines, template arguments and the
   state-diagram spacing were restored by hand from a whitespace-collapsed
   copy, and the index hash of server-models.cpp predates these amendments;
   regenerate with git format-patch before applying.

 common/arg.cpp                 |   7 ++
 common/common.h                |   1 +
 tools/server/server-models.cpp | 136 ++++++++++++++++++++++++++++++++++++++++-
 tools/server/server-models.h   |  25 +++++--
 tools/server/server.cpp        |   5 ++
 5 files changed, 166 insertions(+), 8 deletions(-)

diff --git a/common/arg.cpp b/common/arg.cpp
index 7ba0f2fc25..710955a86f 100644
--- a/common/arg.cpp
+++ b/common/arg.cpp
@@ -3308,6 +3308,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
             params.offline = true;
         }
     ).set_env("LLAMA_OFFLINE"));
+    add_opt(common_arg(
+        {"--download-only"},
+        "Download the model file(s) and exit",
+        [](common_params & params) {
+            params.download_only = true;
+        }
+    ));
     add_opt(common_arg(
         {"-lv", "--verbosity", "--log-verbosity"}, "N",
         string_format("Set the verbosity threshold. Messages with a higher verbosity will be ignored. Values:\n"
diff --git a/common/common.h b/common/common.h
index 2996d35404..066e576650 100644
--- a/common/common.h
+++ b/common/common.h
@@ -482,6 +482,7 @@ struct common_params {
     int32_t control_vector_layer_start = -1; // layer range for control vector
     int32_t control_vector_layer_end   = -1; // layer range for control vector
     bool    offline                    = false;
+    bool    download_only              = false; // only download the model if required, don't start the server
 
     int32_t ppl_stride      = 0;     // stride for perplexity calculations. If left at 0, the pre-existing approach will be used.
     int32_t ppl_output_type = 0;     // = 0 -> ppl output is as usual, = 1 -> ppl output is num_tokens, ppl, one per line
diff --git a/tools/server/server-models.cpp b/tools/server/server-models.cpp
index 96a291854d..9f34a8cbc1 100644
--- a/tools/server/server-models.cpp
+++ b/tools/server/server-models.cpp
@@ -604,12 +604,36 @@ void server_models::unload_lru(const device_memory_map & dmm_req) {
     }
 }
 
+// Resolve the local model file for a preset: the explicit model path if one
+// is set, otherwise the model_path reported by common_download_model() called
+// with opts.offline = true. Returns "" if the preset has no path, hf_repo or url.
+static std::string resolve_model_path(const common_preset & preset) {
+    common_params params;
+    preset.apply_to_params(params);
+
+    if (!params.model.path.empty()) {
+        return params.model.path;
+    }
+
+    if (!params.model.hf_repo.empty() || !params.model.url.empty()) {
+        common_download_opts opts;
+        opts.offline = true;
+        auto result = common_download_model(params.model, opts);
+        return result.model_path;
+    }
+
+    return "";
+}
+
 static device_memory_map get_model_memory_per_device(const common_preset & preset) {
     common_params params;
     preset.apply_to_params(params);
 
     if (params.model.path.empty()) {
-        return {};
+        params.model.path = resolve_model_path(preset);
+        if (params.model.path.empty()) {
+            return {};
+        }
     }
 
     struct log_ud_t {
@@ -661,11 +685,111 @@ static device_memory_map get_model_memory_per_device(const common_preset & prese
     return result;
 }
 
+// Download the files of model `name` by spawning a child process from the
+// preset's to_args(bin_path) plus --download-only and forwarding its output
+// to the log. Blocks until the child has been joined; returns true only if
+// the join succeeded and the child exited with code 0.
+bool server_models::download_model(const std::string & name) {
+    std::vector<std::string> child_args;
+    std::vector<std::string> child_env;
+    {
+        std::lock_guard<std::mutex> lk(mutex);
+        auto & meta = mapping[name].meta;
+        child_args = meta.preset.to_args(bin_path);
+        child_env = base_env;
+    }
+    child_args.push_back("--download-only");
+
+    SRV_INF("downloading model name=%s\n", name.c_str());
+
+    std::vector<char *> argv = to_char_ptr_array(child_args);
+    std::vector<char *> envp = to_char_ptr_array(child_env);
+
+    subprocess_s proc;
+    int options = subprocess_option_no_window | subprocess_option_combined_stdout_stderr;
+    if (subprocess_create_ex(argv.data(), options, envp.data(), &proc) != 0) {
+        SRV_ERR("failed to spawn download process for model name=%s\n", name.c_str());
+        return false;
+    }
+
+    FILE * out = subprocess_stdout(&proc);
+    if (out) {
+        char buffer[4096];
+        while (fgets(buffer, sizeof(buffer), out) != nullptr) {
+            LOG("[dl:%s] %s", name.c_str(), buffer);
+        }
+    }
+
+    int exit_code = 0;
+    const int join_res = subprocess_join(&proc, &exit_code);
+    subprocess_destroy(&proc);
+
+    // exit_code still holds its initial 0 if the join itself failed
+    if (join_res != 0) {
+        SRV_ERR("failed to wait for download process for model name=%s\n", name.c_str());
+        return false;
+    }
+
+    if (exit_code != 0) {
+        SRV_ERR("download process for model name=%s exited with code %d\n", name.c_str(), exit_code);
+        return false;
+    }
+
+    SRV_INF("download complete for model name=%s\n", name.c_str());
+    return true;
+}
+
 void server_models::load(const std::string & name) {
     if (!has_model(name)) {
         throw std::runtime_error("model name=" + name + " is not found");
     }
 
+    // If no local model file can be resolved, mark the model as downloading,
+    // download it on a detached thread and load it from there afterwards.
+    {
+        common_preset preset_copy;
+        {
+            std::lock_guard<std::mutex> lk(mutex);
+            preset_copy = mapping[name].meta.preset;
+        }
+        if (resolve_model_path(preset_copy).empty()) {
+            {
+                std::lock_guard<std::mutex> lk(mutex);
+                auto & meta = mapping[name].meta;
+                if (meta.status != SERVER_MODEL_STATUS_UNLOADED) {
+                    return;
+                }
+                meta.status = SERVER_MODEL_STATUS_DOWNLOADING;
+                cv.notify_all();
+            }
+            std::thread([this, name]() {
+                // an exception leaving a detached thread ends in std::terminate,
+                // so the download and the memory estimation are guarded as well
+                try {
+                    if (!download_model(name)) {
+                        update_status(name, SERVER_MODEL_STATUS_UNLOADED, 1);
+                        return;
+                    }
+                    device_memory_map mem;
+                    if (base_params.models_memory_margin > 0) {
+                        std::lock_guard<std::mutex> lk(mutex);
+                        auto & meta = mapping[name].meta;
+                        meta.dmm_req = get_model_memory_per_device(meta.preset);
+                        if (meta.dmm_req.empty()) {
+                            SRV_WRN("failed to estimate memory for model %s, memory limits will not apply\n", name.c_str());
+                        }
+                        mem = meta.dmm_req;
+                    }
+                    update_status(name, SERVER_MODEL_STATUS_UNLOADED, 0);
+                    _load(name, mem);
+                } catch (const std::exception & e) {
+                    SRV_ERR("failed to download or load model %s: %s\n", name.c_str(), e.what());
+                    update_status(name, SERVER_MODEL_STATUS_UNLOADED, 1);
+                }
+            }).detach();
+            return;
+        }
+    }
+
     device_memory_map dmm_req;
     if (base_params.models_memory_margin > 0) {
         // determine the required memory by the model upon its first load
@@ -673,11 +797,18 @@ void server_models::load(const std::string & name) {
         auto & meta = mapping[name].meta;
         if (meta.dmm_req.empty()) {
             meta.dmm_req = get_model_memory_per_device(meta.preset);
+            if (meta.dmm_req.empty()) {
+                SRV_WRN("failed to estimate memory for model %s, memory limits will not apply\n", name.c_str());
+            }
         }
 
         dmm_req = meta.dmm_req;
     }
 
+    _load(name, dmm_req);
+}
+
+void server_models::_load(const std::string & name, const device_memory_map & dmm_req) {
     unload_lru(dmm_req);
 
     std::lock_guard<std::mutex> lk(mutex);
@@ -913,7 +1044,8 @@ void server_models::wait_until_loading_finished(const std::string & name) {
     cv.wait(lk, [this, &name]() {
         auto it = mapping.find(name);
         if (it != mapping.end()) {
-            return it->second.meta.status != SERVER_MODEL_STATUS_LOADING;
+            return it->second.meta.status != SERVER_MODEL_STATUS_LOADING &&
+                   it->second.meta.status != SERVER_MODEL_STATUS_DOWNLOADING;
         }
         return false;
     });
diff --git a/tools/server/server-models.h b/tools/server/server-models.h
index 567e716bce..aa6abf7cac 100644
--- a/tools/server/server-models.h
+++ b/tools/server/server-models.h
@@ -14,6 +14,9 @@
 /**
  * state diagram:
  *
+ *
+ *    ┌► DOWNLOADING ─┐
+ *    │               ▼
  * UNLOADED ──► LOADING ──► LOADED ◄──── SLEEPING
  *  ▲            │            │             ▲
  *  └───failed───┘            │             │
@@ -21,8 +24,8 @@
  *  └────────unloaded─────────┘
  */
 enum server_model_status {
-    // TODO: also add downloading state when the logic is added
     SERVER_MODEL_STATUS_UNLOADED,
+    SERVER_MODEL_STATUS_DOWNLOADING,
     SERVER_MODEL_STATUS_LOADING,
     SERVER_MODEL_STATUS_LOADED,
     SERVER_MODEL_STATUS_SLEEPING
@@ -32,6 +35,9 @@ static server_model_status server_model_status_from_string(const std::string & s
     if (status_str == "unloaded") {
         return SERVER_MODEL_STATUS_UNLOADED;
     }
+    if (status_str == "downloading") {
+        return SERVER_MODEL_STATUS_DOWNLOADING;
+    }
     if (status_str == "loading") {
         return SERVER_MODEL_STATUS_LOADING;
     }
@@ -46,11 +52,12 @@ static server_model_status server_model_status_from_string(const std::string & s
 
 static std::string server_model_status_to_string(server_model_status status) {
     switch (status) {
-        case SERVER_MODEL_STATUS_UNLOADED: return "unloaded";
-        case SERVER_MODEL_STATUS_LOADING:  return "loading";
-        case SERVER_MODEL_STATUS_LOADED:   return "loaded";
-        case SERVER_MODEL_STATUS_SLEEPING: return "sleeping";
-        default:                           return "unknown";
+        case SERVER_MODEL_STATUS_UNLOADED:    return "unloaded";
+        case SERVER_MODEL_STATUS_DOWNLOADING: return "downloading";
+        case SERVER_MODEL_STATUS_LOADING:     return "loading";
+        case SERVER_MODEL_STATUS_LOADED:      return "loaded";
+        case SERVER_MODEL_STATUS_SLEEPING:    return "sleeping";
+        default:                              return "unknown";
     }
 }
 
@@ -126,6 +133,12 @@ private:
     // not thread-safe, caller must hold mutex
     int can_fit(const device_memory_map & dmm_req) const;
 
+    // download model files, blocking call (caller must NOT hold mutex)
+    bool download_model(const std::string & name);
+
+    // Internal helper for model loading
+    void _load(const std::string & name, const device_memory_map & dmm_req);
+
 public:
     server_models(const common_params & params, int argc, char ** argv);
 
diff --git a/tools/server/server.cpp b/tools/server/server.cpp
index 6566949edf..4ff962b89f 100644
--- a/tools/server/server.cpp
+++ b/tools/server/server.cpp
@@ -83,6 +83,11 @@ int main(int argc, char ** argv) {
         return 1;
     }
 
+    if (params.download_only) {
+        LOG_INF("%s: model downloaded successfully, exiting\n", __func__);
+        return 0;
+    }
+
     // validate batch size for embeddings
     // embeddings require all tokens to be processed in a single ubatch
     // see https://github.com/ggml-org/llama.cpp/issues/12836