diff --git a/common/common.h b/common/common.h index 8ac5b9a8bd..2996d35404 100644 --- a/common/common.h +++ b/common/common.h @@ -610,7 +610,7 @@ struct common_params { std::string models_dir = ""; // directory containing models for the router server std::string models_preset = ""; // directory containing model presets for the router server int models_max = 4; // maximum number of models to load simultaneously - int models_memory_margin = 1024; // MB of free memory to preserve per device (0 = disabled) + int models_memory_margin = 1024; // MiB of free memory to preserve per device (0 = disabled) bool models_autoload = true; // automatically load models when requested via the router server bool log_json = false; diff --git a/tools/server/server-models.cpp b/tools/server/server-models.cpp index ef6acb57de..96a291854d 100644 --- a/tools/server/server-models.cpp +++ b/tools/server/server-models.cpp @@ -180,11 +180,11 @@ server_models::server_models( bin_path = get_server_exec_path().string(); } catch (const std::exception & e) { bin_path = argv[0]; - LOG_WRN("failed to get server executable path: %s\n", e.what()); - LOG_WRN("using original argv[0] as fallback: %s\n", argv[0]); + SRV_WRN("failed to get server executable path: %s\n", e.what()); + SRV_WRN("using original argv[0] as fallback: %s\n", argv[0]); } - const uint64_t memory_margin = (uint64_t)base_params.models_memory_margin * 1024 * 1024; + const size_t memory_margin = (size_t) base_params.models_memory_margin * 1024 * 1024; if (memory_margin > 0) { const size_t n_devs = ggml_backend_dev_count(); @@ -193,11 +193,10 @@ server_models::server_models( size_t free, total; ggml_backend_dev_memory(dev, &free, &total); if (total > 0) { - const uint64_t available = (free > memory_margin) ? free - memory_margin : 0; - available_memory_per_device[dev] = available; - SRV_DBG("device %s: available memory after margin=%lu MiB\n", - ggml_backend_dev_name(dev), - (unsigned long)(available / (1024 * 1024))); + const size_t available = (free > memory_margin) ? free - memory_margin : 0; + dmm_available[dev] = available; + SRV_DBG("device %s: available memory after margin=%zu MiB\n", + ggml_backend_dev_name(dev), available / (1024 * 1024)); } } } @@ -518,52 +517,57 @@ std::vector server_models::get_all_meta() { return result; } -uint64_t server_models::get_memory_exceeded(const model_memory_map & new_model_memory_per_device) const { - model_memory_map total_memory_per_device; +int server_models::can_fit(const device_memory_map & dmm_req) const { + device_memory_map dmm_total; for (const auto & m : mapping) { if (m.second.meta.is_running()) { - for (const auto & [key, value] : m.second.meta.memory_usage_per_device) { - total_memory_per_device[key] += value; + for (const auto & [dev, mem] : m.second.meta.dmm_req) { + dmm_total[dev] += mem; } } } - auto get = [](const model_memory_map & m, ggml_backend_dev_t k) { - auto it = m.find(k); - return it != m.end() ? it->second : 0; + auto get = [](const device_memory_map & dmm, ggml_backend_dev_t dev) { + auto it = dmm.find(dev); + return it != dmm.end() ? it->second : 0; }; - size_t count_memory_exceeded = 0; + int res = 0; - for (const auto & [key, limit] : available_memory_per_device) { - const uint64_t total_memory = get(total_memory_per_device, key); - const uint64_t new_memory = get(new_model_memory_per_device, key); - SRV_DBG("device %s: total=%lu MB, new=%lu MB, limit=%lu MB\n", - ggml_backend_dev_name(key), - (unsigned long)(total_memory / (1024 * 1024)), - (unsigned long)(new_memory / (1024 * 1024)), - (unsigned long)(limit / (1024 * 1024))); + for (const auto & [dev, limit] : dmm_available) { + const size_t mem_total = get(dmm_total, dev); + const size_t mem_new = get(dmm_req, dev); - if (total_memory + new_memory > limit) { - count_memory_exceeded++; + SRV_DBG("device %s: total=%zu MiB, new=%zu MiB, limit=%zu MiB\n", + ggml_backend_dev_name(dev), + mem_total / (1024 * 1024), mem_new / (1024 * 1024), limit / (1024 * 1024)); + + if (mem_total + mem_new > limit) { + res++; } } - return count_memory_exceeded; + return res; } -void server_models::unload_lru(const model_memory_map & new_model_memory_per_device) { - const bool check_memory = base_params.models_memory_margin > 0 && !available_memory_per_device.empty(); +void server_models::unload_lru(const device_memory_map & dmm_req) { + const bool check_active = base_params.models_max > 0; + const bool check_memory = base_params.models_memory_margin > 0; - if (base_params.models_max <= 0 && !check_memory) { + if (!check_active && !check_memory) { return; // no limit } + if (check_memory) { + GGML_ASSERT(!dmm_available.empty()); + } + while (true) { - std::string lru_model_name = ""; + std::string lru_model_name; int64_t lru_last_used = ggml_time_ms(); - size_t count_active = 0; - size_t count_memory_exceeded = 0; + + int count_active = 0; + int count_exceed = 0; { std::unique_lock lk(mutex); for (const auto & m : mapping) { @@ -575,14 +579,17 @@ void server_models::unload_lru(const model_memory_map & new_model_memory_per_dev } } } - count_memory_exceeded = get_memory_exceeded(new_model_memory_per_device); + if (check_memory) { + count_exceed = can_fit(dmm_req); + } } - bool count_exceeded = base_params.models_max > 0 && - (count_active + 1) > (size_t)base_params.models_max; - if (!lru_model_name.empty() && (count_exceeded || count_memory_exceeded > 0)) { - SRV_INF("limits reached (count=%zu, memory margin exceeded on %zu device(s)), removing LRU name=%s\n", - count_active, count_memory_exceeded, lru_model_name.c_str()); + const bool active_exceeded = check_active && count_active >= base_params.models_max; + const bool memory_exceeded = check_memory && count_exceed > 0; + + if (!lru_model_name.empty() && (active_exceeded || memory_exceeded)) { + SRV_INF("limits reached (count=%d, memory margin exceeded on %d device(s)), removing LRU name=%s\n", + count_active, count_exceed, lru_model_name.c_str()); unload(lru_model_name); // wait for unload to complete { @@ -597,11 +604,11 @@ void server_models::unload_lru(const model_memory_map & new_model_memory_per_dev } } -static model_memory_map get_model_memory_per_device(const common_preset& preset) { +static device_memory_map get_model_memory_per_device(const common_preset & preset) { common_params params; preset.apply_to_params(params); - if(params.model.path.empty()) { + if (params.model.path.empty()) { return {}; } @@ -641,7 +648,7 @@ static model_memory_map get_model_memory_per_device(const common_preset& preset) return {}; } - model_memory_map result; + device_memory_map result; const size_t n_devs = ggml_backend_dev_count(); for (size_t i = 0; i < n_devs; i++) { ggml_backend_dev_t dev = ggml_backend_dev_get(i); @@ -659,18 +666,19 @@ void server_models::load(const std::string & name) { throw std::runtime_error("model name=" + name + " is not found"); } - model_memory_map new_model_memory_per_device; + device_memory_map dmm_req; if (base_params.models_memory_margin > 0) { + // determine the required memory by the model upon its first load std::lock_guard lk(mutex); auto & meta = mapping[name].meta; - if (meta.memory_usage_per_device.empty()) { - meta.memory_usage_per_device = get_model_memory_per_device(meta.preset); + if (meta.dmm_req.empty()) { + meta.dmm_req = get_model_memory_per_device(meta.preset); } - new_model_memory_per_device = meta.memory_usage_per_device; + dmm_req = meta.dmm_req; } - unload_lru(new_model_memory_per_device); + unload_lru(dmm_req); std::lock_guard lk(mutex); @@ -684,17 +692,24 @@ void server_models::load(const std::string & name) { // exceeding models_max. Without this, the window between unload_lru() // releasing its lock and this lock_guard acquiring allows multiple // threads to each observe capacity and all proceed to load. - if (base_params.models_max > 0 || base_params.models_memory_margin > 0) { - size_t count_active = 0; - for (const auto & m : mapping) { - if (m.second.meta.is_running()) { - count_active++; + { + const bool check_active = base_params.models_max > 0; + const bool check_memory = base_params.models_memory_margin > 0; + + if (check_active || check_memory) { + int count_active = 0; + for (const auto & m : mapping) { + if (m.second.meta.is_running()) { + count_active++; + } + } + + const bool active_exceeded = check_active && count_active >= base_params.models_max; + const bool memory_exceeded = check_memory && can_fit(dmm_req) > 0; + + if (active_exceeded || memory_exceeded) { + throw std::runtime_error("model limit reached, try again later"); } - } - bool count_exceeded = base_params.models_max > 0 && count_active >= (size_t)base_params.models_max; - bool memory_exceeded = get_memory_exceeded(new_model_memory_per_device) > 0; - if (count_exceeded || memory_exceeded) { - throw std::runtime_error("model limit reached, try again later"); } } diff --git a/tools/server/server-models.h b/tools/server/server-models.h index f86cc0b2cc..567e716bce 100644 --- a/tools/server/server-models.h +++ b/tools/server/server-models.h @@ -54,7 +54,7 @@ static std::string server_model_status_to_string(server_model_status status) { } } -using model_memory_map = std::map; +using device_memory_map = std::map; struct server_model_meta { common_preset preset; @@ -64,7 +64,7 @@ struct server_model_meta { int port = 0; server_model_status status = SERVER_MODEL_STATUS_UNLOADED; int64_t last_used = 0; // for LRU unloading - model_memory_map memory_usage_per_device; // bytes used per device + device_memory_map dmm_req; // bytes required per device std::vector args; // args passed to the model instance, will be populated by render_args() int exit_code = 0; // exit code of the model instance process (only valid if status == FAILED) int stop_timeout = 0; // seconds to wait before force-killing the model instance during shutdown @@ -111,18 +111,20 @@ private: common_preset base_preset; // base preset from llama-server CLI args // available memory per device - model_memory_map available_memory_per_device; + device_memory_map dmm_available; void update_meta(const std::string & name, const server_model_meta & meta); // unload least recently used models if the limit is reached - void unload_lru(const model_memory_map & new_model_memory_per_device); + void unload_lru(const device_memory_map & dmm_req); // not thread-safe, caller must hold mutex void add_model(server_model_meta && meta); + // return number of devices where the memory limit would be exceeded + // return 0 if the new model would fit on all devices // not thread-safe, caller must hold mutex - uint64_t get_memory_exceeded(const model_memory_map & new_model_memory_per_device) const; + int can_fit(const device_memory_map & dmm_req) const; public: server_models(const common_params & params, int argc, char ** argv);