handle models that need to be downloaded before estimation

This commit is contained in:
Ruben Ortlam
2026-04-20 14:48:55 +02:00
parent 1a8aec0afd
commit b1623a614c
5 changed files with 151 additions and 9 deletions

View File

@@ -604,12 +604,33 @@ void server_models::unload_lru(const device_memory_map & dmm_req) {
}
}
static std::string resolve_model_path(const common_preset & preset) {
common_params params;
preset.apply_to_params(params);
if (!params.model.path.empty()) {
return params.model.path;
}
if (!params.model.hf_repo.empty() || !params.model.url.empty()) {
common_download_opts opts;
opts.offline = true;
auto result = common_download_model(params.model, opts);
return result.model_path;
}
return "";
}
static device_memory_map get_model_memory_per_device(const common_preset & preset) {
common_params params;
preset.apply_to_params(params);
if (params.model.path.empty()) {
return {};
if(params.model.path.empty()) {
params.model.path = resolve_model_path(preset);
if(params.model.path.empty()) {
return {};
}
}
struct log_ud_t {
@@ -661,11 +682,98 @@ static device_memory_map get_model_memory_per_device(const common_preset & prese
return result;
}
bool server_models::download_model(const std::string & name) {
std::vector<std::string> child_args;
std::vector<std::string> child_env;
{
std::lock_guard<std::mutex> lk(mutex);
auto & meta = mapping[name].meta;
child_args = meta.preset.to_args(bin_path);
child_env = base_env;
}
child_args.push_back("--download-only");
SRV_INF("downloading model name=%s\n", name.c_str());
std::vector<char *> argv = to_char_ptr_array(child_args);
std::vector<char *> envp = to_char_ptr_array(child_env);
subprocess_s proc;
int options = subprocess_option_no_window | subprocess_option_combined_stdout_stderr;
if (subprocess_create_ex(argv.data(), options, envp.data(), &proc) != 0) {
SRV_ERR("failed to spawn download process for model name=%s\n", name.c_str());
return false;
}
FILE * out = subprocess_stdout(&proc);
if (out) {
char buffer[4096];
while (fgets(buffer, sizeof(buffer), out) != nullptr) {
LOG("[dl:%s] %s", name.c_str(), buffer);
}
}
int exit_code = 0;
subprocess_join(&proc, &exit_code);
subprocess_destroy(&proc);
if (exit_code != 0) {
SRV_ERR("download process for model name=%s exited with code %d\n", name.c_str(), exit_code);
return false;
}
SRV_INF("download complete for model name=%s\n", name.c_str());
return true;
}
void server_models::load(const std::string & name) {
if (!has_model(name)) {
throw std::runtime_error("model name=" + name + " is not found");
}
{
common_preset preset_copy;
{
std::lock_guard<std::mutex> lk(mutex);
preset_copy = mapping[name].meta.preset;
}
if (resolve_model_path(preset_copy).empty()) {
{
std::lock_guard<std::mutex> lk(mutex);
auto & meta = mapping[name].meta;
if (meta.status != SERVER_MODEL_STATUS_UNLOADED) {
return;
}
meta.status = SERVER_MODEL_STATUS_DOWNLOADING;
cv.notify_all();
}
std::thread([this, name]() {
if (!download_model(name)) {
update_status(name, SERVER_MODEL_STATUS_UNLOADED, 1);
return;
}
device_memory_map mem;
if (base_params.models_memory_margin > 0) {
std::lock_guard<std::mutex> lk(mutex);
auto & meta = mapping[name].meta;
meta.dmm_req = get_model_memory_per_device(meta.preset);
if (meta.dmm_req.empty()) {
SRV_WRN("failed to estimate memory for model %s, memory limits will not apply\n", name.c_str());
}
mem = meta.dmm_req;
}
update_status(name, SERVER_MODEL_STATUS_UNLOADED, 0);
try {
_load(name, mem);
} catch (const std::exception & e) {
SRV_ERR("failed to load model %s after download: %s\n", name.c_str(), e.what());
update_status(name, SERVER_MODEL_STATUS_UNLOADED, 1);
}
}).detach();
return;
}
}
device_memory_map dmm_req;
if (base_params.models_memory_margin > 0) {
// determine the required memory by the model upon its first load
@@ -673,11 +781,18 @@ void server_models::load(const std::string & name) {
auto & meta = mapping[name].meta;
if (meta.dmm_req.empty()) {
meta.dmm_req = get_model_memory_per_device(meta.preset);
if (meta.dmm_req.empty()) {
SRV_WRN("failed to estimate memory for model %s, memory limits will not apply\n", name.c_str());
}
}
dmm_req = meta.dmm_req;
}
_load(name, dmm_req);
}
void server_models::_load(const std::string & name, const device_memory_map & dmm_req) {
unload_lru(dmm_req);
std::lock_guard<std::mutex> lk(mutex);
@@ -913,7 +1028,8 @@ void server_models::wait_until_loading_finished(const std::string & name) {
cv.wait(lk, [this, &name]() {
auto it = mapping.find(name);
if (it != mapping.end()) {
return it->second.meta.status != SERVER_MODEL_STATUS_LOADING;
return it->second.meta.status != SERVER_MODEL_STATUS_LOADING &&
it->second.meta.status != SERVER_MODEL_STATUS_DOWNLOADING;
}
return false;
});