mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2026-05-13 12:34:05 +00:00
handle models that need to be downloaded before estimation
This commit is contained in:
@@ -3308,6 +3308,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
params.offline = true;
|
||||
}
|
||||
).set_env("LLAMA_OFFLINE"));
|
||||
add_opt(common_arg(
|
||||
{"--download-only"},
|
||||
"Download the model file(s) and exit",
|
||||
[](common_params & params) {
|
||||
params.download_only = true;
|
||||
}
|
||||
));
|
||||
add_opt(common_arg(
|
||||
{"-lv", "--verbosity", "--log-verbosity"}, "N",
|
||||
string_format("Set the verbosity threshold. Messages with a higher verbosity will be ignored. Values:\n"
|
||||
|
||||
@@ -482,6 +482,7 @@ struct common_params {
|
||||
int32_t control_vector_layer_start = -1; // layer range for control vector
|
||||
int32_t control_vector_layer_end = -1; // layer range for control vector
|
||||
bool offline = false;
|
||||
bool download_only = false; // only download the model if required, don't start the server
|
||||
|
||||
int32_t ppl_stride = 0; // stride for perplexity calculations. If left at 0, the pre-existing approach will be used.
|
||||
int32_t ppl_output_type = 0; // = 0 -> ppl output is as usual, = 1 -> ppl output is num_tokens, ppl, one per line
|
||||
|
||||
@@ -604,12 +604,33 @@ void server_models::unload_lru(const device_memory_map & dmm_req) {
|
||||
}
|
||||
}
|
||||
|
||||
static std::string resolve_model_path(const common_preset & preset) {
|
||||
common_params params;
|
||||
preset.apply_to_params(params);
|
||||
|
||||
if (!params.model.path.empty()) {
|
||||
return params.model.path;
|
||||
}
|
||||
|
||||
if (!params.model.hf_repo.empty() || !params.model.url.empty()) {
|
||||
common_download_opts opts;
|
||||
opts.offline = true;
|
||||
auto result = common_download_model(params.model, opts);
|
||||
return result.model_path;
|
||||
}
|
||||
|
||||
return "";
|
||||
}
|
||||
|
||||
static device_memory_map get_model_memory_per_device(const common_preset & preset) {
|
||||
common_params params;
|
||||
preset.apply_to_params(params);
|
||||
|
||||
if (params.model.path.empty()) {
|
||||
return {};
|
||||
if(params.model.path.empty()) {
|
||||
params.model.path = resolve_model_path(preset);
|
||||
if(params.model.path.empty()) {
|
||||
return {};
|
||||
}
|
||||
}
|
||||
|
||||
struct log_ud_t {
|
||||
@@ -661,11 +682,98 @@ static device_memory_map get_model_memory_per_device(const common_preset & prese
|
||||
return result;
|
||||
}
|
||||
|
||||
bool server_models::download_model(const std::string & name) {
|
||||
std::vector<std::string> child_args;
|
||||
std::vector<std::string> child_env;
|
||||
{
|
||||
std::lock_guard<std::mutex> lk(mutex);
|
||||
auto & meta = mapping[name].meta;
|
||||
child_args = meta.preset.to_args(bin_path);
|
||||
child_env = base_env;
|
||||
}
|
||||
child_args.push_back("--download-only");
|
||||
|
||||
SRV_INF("downloading model name=%s\n", name.c_str());
|
||||
|
||||
std::vector<char *> argv = to_char_ptr_array(child_args);
|
||||
std::vector<char *> envp = to_char_ptr_array(child_env);
|
||||
|
||||
subprocess_s proc;
|
||||
int options = subprocess_option_no_window | subprocess_option_combined_stdout_stderr;
|
||||
if (subprocess_create_ex(argv.data(), options, envp.data(), &proc) != 0) {
|
||||
SRV_ERR("failed to spawn download process for model name=%s\n", name.c_str());
|
||||
return false;
|
||||
}
|
||||
|
||||
FILE * out = subprocess_stdout(&proc);
|
||||
if (out) {
|
||||
char buffer[4096];
|
||||
while (fgets(buffer, sizeof(buffer), out) != nullptr) {
|
||||
LOG("[dl:%s] %s", name.c_str(), buffer);
|
||||
}
|
||||
}
|
||||
|
||||
int exit_code = 0;
|
||||
subprocess_join(&proc, &exit_code);
|
||||
subprocess_destroy(&proc);
|
||||
|
||||
if (exit_code != 0) {
|
||||
SRV_ERR("download process for model name=%s exited with code %d\n", name.c_str(), exit_code);
|
||||
return false;
|
||||
}
|
||||
|
||||
SRV_INF("download complete for model name=%s\n", name.c_str());
|
||||
return true;
|
||||
}
|
||||
|
||||
void server_models::load(const std::string & name) {
|
||||
if (!has_model(name)) {
|
||||
throw std::runtime_error("model name=" + name + " is not found");
|
||||
}
|
||||
|
||||
{
|
||||
common_preset preset_copy;
|
||||
{
|
||||
std::lock_guard<std::mutex> lk(mutex);
|
||||
preset_copy = mapping[name].meta.preset;
|
||||
}
|
||||
if (resolve_model_path(preset_copy).empty()) {
|
||||
{
|
||||
std::lock_guard<std::mutex> lk(mutex);
|
||||
auto & meta = mapping[name].meta;
|
||||
if (meta.status != SERVER_MODEL_STATUS_UNLOADED) {
|
||||
return;
|
||||
}
|
||||
meta.status = SERVER_MODEL_STATUS_DOWNLOADING;
|
||||
cv.notify_all();
|
||||
}
|
||||
std::thread([this, name]() {
|
||||
if (!download_model(name)) {
|
||||
update_status(name, SERVER_MODEL_STATUS_UNLOADED, 1);
|
||||
return;
|
||||
}
|
||||
device_memory_map mem;
|
||||
if (base_params.models_memory_margin > 0) {
|
||||
std::lock_guard<std::mutex> lk(mutex);
|
||||
auto & meta = mapping[name].meta;
|
||||
meta.dmm_req = get_model_memory_per_device(meta.preset);
|
||||
if (meta.dmm_req.empty()) {
|
||||
SRV_WRN("failed to estimate memory for model %s, memory limits will not apply\n", name.c_str());
|
||||
}
|
||||
mem = meta.dmm_req;
|
||||
}
|
||||
update_status(name, SERVER_MODEL_STATUS_UNLOADED, 0);
|
||||
try {
|
||||
_load(name, mem);
|
||||
} catch (const std::exception & e) {
|
||||
SRV_ERR("failed to load model %s after download: %s\n", name.c_str(), e.what());
|
||||
update_status(name, SERVER_MODEL_STATUS_UNLOADED, 1);
|
||||
}
|
||||
}).detach();
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
device_memory_map dmm_req;
|
||||
if (base_params.models_memory_margin > 0) {
|
||||
// determine the required memory by the model upon its first load
|
||||
@@ -673,11 +781,18 @@ void server_models::load(const std::string & name) {
|
||||
auto & meta = mapping[name].meta;
|
||||
if (meta.dmm_req.empty()) {
|
||||
meta.dmm_req = get_model_memory_per_device(meta.preset);
|
||||
if (meta.dmm_req.empty()) {
|
||||
SRV_WRN("failed to estimate memory for model %s, memory limits will not apply\n", name.c_str());
|
||||
}
|
||||
}
|
||||
|
||||
dmm_req = meta.dmm_req;
|
||||
}
|
||||
|
||||
_load(name, dmm_req);
|
||||
}
|
||||
|
||||
void server_models::_load(const std::string & name, const device_memory_map & dmm_req) {
|
||||
unload_lru(dmm_req);
|
||||
|
||||
std::lock_guard<std::mutex> lk(mutex);
|
||||
@@ -913,7 +1028,8 @@ void server_models::wait_until_loading_finished(const std::string & name) {
|
||||
cv.wait(lk, [this, &name]() {
|
||||
auto it = mapping.find(name);
|
||||
if (it != mapping.end()) {
|
||||
return it->second.meta.status != SERVER_MODEL_STATUS_LOADING;
|
||||
return it->second.meta.status != SERVER_MODEL_STATUS_LOADING &&
|
||||
it->second.meta.status != SERVER_MODEL_STATUS_DOWNLOADING;
|
||||
}
|
||||
return false;
|
||||
});
|
||||
|
||||
@@ -14,6 +14,9 @@
|
||||
/**
|
||||
* state diagram:
|
||||
*
|
||||
*
|
||||
* ┌► DOWNLOADING ─┐
|
||||
* │ ▼
|
||||
* UNLOADED ──► LOADING ──► LOADED ◄──── SLEEPING
|
||||
* ▲ │ │ ▲
|
||||
* └───failed───┘ │ │
|
||||
@@ -21,8 +24,8 @@
|
||||
* └────────unloaded─────────┘
|
||||
*/
|
||||
enum server_model_status {
|
||||
// DOWNLOADING: the model's files are being fetched before the model is loaded
|
||||
SERVER_MODEL_STATUS_UNLOADED,
|
||||
SERVER_MODEL_STATUS_DOWNLOADING,
|
||||
SERVER_MODEL_STATUS_LOADING,
|
||||
SERVER_MODEL_STATUS_LOADED,
|
||||
SERVER_MODEL_STATUS_SLEEPING
|
||||
@@ -32,6 +35,9 @@ static server_model_status server_model_status_from_string(const std::string & s
|
||||
if (status_str == "unloaded") {
|
||||
return SERVER_MODEL_STATUS_UNLOADED;
|
||||
}
|
||||
if (status_str == "downloading") {
|
||||
return SERVER_MODEL_STATUS_DOWNLOADING;
|
||||
}
|
||||
if (status_str == "loading") {
|
||||
return SERVER_MODEL_STATUS_LOADING;
|
||||
}
|
||||
@@ -46,11 +52,12 @@ static server_model_status server_model_status_from_string(const std::string & s
|
||||
|
||||
// Convert a server_model_status value to its lowercase name.
// Each status has exactly one case label (duplicate labels are ill-formed in a
// switch); any value without a label maps to "unknown".
static std::string server_model_status_to_string(server_model_status status) {
    switch (status) {
        case SERVER_MODEL_STATUS_UNLOADED:    return "unloaded";
        case SERVER_MODEL_STATUS_DOWNLOADING: return "downloading";
        case SERVER_MODEL_STATUS_LOADING:     return "loading";
        case SERVER_MODEL_STATUS_LOADED:      return "loaded";
        case SERVER_MODEL_STATUS_SLEEPING:    return "sleeping";
        default:                              return "unknown";
    }
}
|
||||
|
||||
@@ -126,6 +133,12 @@ private:
|
||||
// not thread-safe, caller must hold mutex
|
||||
int can_fit(const device_memory_map & dmm_req) const;
|
||||
|
||||
// download model files, blocking call (caller must NOT hold mutex)
|
||||
bool download_model(const std::string & name);
|
||||
|
||||
// Internal helper for model loading
|
||||
void _load(const std::string & name, const device_memory_map & dmm_req);
|
||||
|
||||
public:
|
||||
server_models(const common_params & params, int argc, char ** argv);
|
||||
|
||||
|
||||
@@ -83,6 +83,11 @@ int main(int argc, char ** argv) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
if (params.download_only) {
|
||||
LOG_INF("%s: model downloaded successfully, exiting\n", __func__);
|
||||
return 0;
|
||||
}
|
||||
|
||||
// validate batch size for embeddings
|
||||
// embeddings require all tokens to be processed in a single ubatch
|
||||
// see https://github.com/ggml-org/llama.cpp/issues/12836
|
||||
|
||||
Reference in New Issue
Block a user