server: add model management and proxy

This commit is contained in:
Xuan Son Nguyen
2025-11-19 21:23:00 +01:00
parent 10e9780154
commit fc5901a449
4 changed files with 798 additions and 48 deletions

View File

@@ -1,6 +1,7 @@
#include "chat.h"
#include "utils.hpp"
#include "server-http.h"
#include "server-models.h"
#include "arg.h"
#include "common.h"
@@ -4452,6 +4453,8 @@ struct server_routes {
const common_params & params;
server_context & ctx_server;
server_http_context & ctx_http; // for reading is_ready
std::unique_ptr<server_models> models = nullptr;
server_routes(const common_params & params, server_context & ctx_server, server_http_context & ctx_http)
: params(params), ctx_server(ctx_server), ctx_http(ctx_http) {}
@@ -5109,6 +5112,115 @@ public:
return res;
};
//
// endpoints for model management (aka router server)
//
// Router-mode handler for the props endpoint.
// - With a non-empty "model" query parameter: the request is handed to proxy_get.
// - Without it: the router answers for itself with a placeholder payload.
server_http_context::handler_t get_router_props = [this](const server_http_req & req) {
    std::string requested_model = req.get_param("model");
    if (!requested_model.empty()) {
        // a specific model was named, let the proxy path deal with it
        return proxy_get(req);
    }

    // main instance
    auto router_res = std::make_unique<server_res_generator>(ctx_server);
    router_res->ok({
        // TODO: add support for this on web UI
        {"role",          "router"},
        {"max_instances", 4}, // dummy value for testing
        // the remaining fields are a dummy response to make sure webui doesn't break
        {"model_alias",   "llama-server"},
        {"model_path",    "none"},
        {"default_generation_settings", {
            {"params", json{}},
            {"n_ctx",  0},
        }},
    });
    return std::unique_ptr<server_http_res>(std::move(router_res));
};
// Generic GET proxy: the target model is taken from the "model" query parameter.
// Calls models->ensure_model_loaded() for that name, then returns whatever
// models->proxy_request() produces for the request.
server_http_context::handler_t proxy_get = [this](const server_http_req & req) {
    std::string model_name  = req.get_param("model");
    std::string http_method = "GET";

    models->ensure_model_loaded(model_name);
    return models->proxy_request(req, http_method, model_name);
};
// Generic POST proxy: the target model is taken from the "model" field of the
// JSON request body (empty string when the field is absent).
// Calls models->ensure_model_loaded() for that name, then returns whatever
// models->proxy_request() produces for the request.
server_http_context::handler_t proxy_post = [this](const server_http_req & req) {
    json payload = json::parse(req.body);
    std::string model_name  = json_value(payload, "model", std::string());
    std::string http_method = "POST";

    models->ensure_model_loaded(model_name);
    return models->proxy_request(req, http_method, model_name);
};
// Handler for POST /models/load.
// Reads the "model" field from the JSON body and asks the models manager to load it.
// Responds with an invalid-request error when the model is unknown or its status
// is already SERVER_MODEL_STATUS_LOADED; otherwise responds {"success": true}.
server_http_context::handler_t post_router_models_load = [this](const server_http_req & req) {
    auto res = std::make_unique<server_res_generator>(ctx_server);

    json payload = json::parse(req.body);
    std::string model_name = json_value(payload, "model", std::string());

    auto meta = models->get_meta(model_name);
    if (!meta.has_value()) {
        res->error(format_error_response("model is not found", ERROR_TYPE_INVALID_REQUEST));
    } else if (meta->status == SERVER_MODEL_STATUS_LOADED) {
        res->error(format_error_response("model is already loaded", ERROR_TYPE_INVALID_REQUEST));
    } else {
        models->load(model_name);
        res->ok({{"success", true}});
    }
    return res;
};
// Handler for POST /models/status.
// Used by a child process to notify the router about a status change: the JSON body
// carries the model name in "model" and the new status, as a string, in "value".
// Always responds {"success": true} once the update call has been made.
server_http_context::handler_t post_router_models_status = [this](const server_http_req & req) {
    auto res = std::make_unique<server_res_generator>(ctx_server);

    json payload = json::parse(req.body);
    std::string model_name  = json_value(payload, "model", std::string());
    std::string status_text = json_value(payload, "value", std::string());

    auto new_status = server_model_status_from_string(status_text);
    models->update_status(model_name, new_status);

    res->ok({{"success", true}});
    return res;
};
// Router-mode handler for listing models.
// Responds with {"data": [...]}, one entry per model returned by models->get_all_meta().
// Each entry repeats the model name under "model", "name" and "id", and reports the
// current status as a string under "status"."value".
server_http_context::handler_t get_router_models = [this](const server_http_req &) {
    auto res = std::make_unique<server_res_generator>(ctx_server);

    // take the result of get_all_meta() into a local before iterating
    auto known_models = models->get_all_meta();

    json listing = json::array();
    for (const auto & meta : known_models) {
        json entry = {
            {"model", meta.name},
            {"name",  meta.name},
            {"id",    meta.name},
            // TODO: other fields...
            {"status", {
                {"value", server_model_status_to_string(meta.status)}
            }},
        };
        listing.push_back(std::move(entry));
    }

    res->ok({{"data", listing}});
    return res;
};
// Handler for POST /models/unload.
// Reads the "model" field from the JSON body and asks the models manager to unload it.
// Responds with an invalid-request error when the model is unknown or its status
// is not SERVER_MODEL_STATUS_LOADED; otherwise responds {"success": true}.
server_http_context::handler_t post_router_models_unload = [this](const server_http_req & req) {
    auto res = std::make_unique<server_res_generator>(ctx_server);

    json payload = json::parse(req.body);
    std::string model_name = json_value(payload, "model", std::string());

    auto meta = models->get_meta(model_name);
    if (!meta.has_value()) {
        res->error(format_error_response("model is not found", ERROR_TYPE_INVALID_REQUEST));
    } else if (meta->status != SERVER_MODEL_STATUS_LOADED) {
        res->error(format_error_response("model is not loaded", ERROR_TYPE_INVALID_REQUEST));
    } else {
        models->unload(model_name);
        res->ok({{"success", true}});
    }
    return res;
};
private:
std::unique_ptr<server_res_generator> handle_completions_impl(
server_task_type type,
@@ -5502,7 +5614,7 @@ static server_http_context::handler_t ex_wrapper(server_http_context::handler_t
};
}
int main(int argc, char ** argv) {
int main(int argc, char ** argv, char ** envp) {
// own arguments required by this example
common_params params;
@@ -5550,6 +5662,36 @@ int main(int argc, char ** argv) {
// register API routes
server_routes routes(params, ctx_server, ctx_http);
bool is_router_server = params.model.path == DEFAULT_MODEL_PATH;
if (is_router_server) {
// setup server instances manager
routes.models.reset(new server_models(params, argc, argv, envp));
// proxy handlers
routes.post_props = routes.proxy_post;
routes.post_completions = routes.proxy_post;
routes.post_completions_oai = routes.proxy_post;
routes.post_chat_completions = routes.proxy_post;
routes.post_infill = routes.proxy_post;
routes.post_embeddings = routes.proxy_post;
routes.post_embeddings_oai = routes.proxy_post;
routes.post_rerank = routes.proxy_post;
routes.post_tokenize = routes.proxy_post;
routes.post_detokenize = routes.proxy_post;
routes.post_apply_template = routes.proxy_post;
routes.get_lora_adapters = routes.proxy_get;
routes.post_lora_adapters = routes.proxy_post;
routes.get_slots = routes.proxy_get;
routes.post_slots = routes.proxy_post;
// custom routes for router
routes.get_props = routes.get_router_props;
routes.get_models = routes.get_router_models;
ctx_http.post("/models/load", ex_wrapper(routes.post_router_models_load));
ctx_http.post("/models/unload", ex_wrapper(routes.post_router_models_unload));
ctx_http.post("/models/status", ex_wrapper(routes.post_router_models_status));
}
ctx_http.get ("/health", ex_wrapper(routes.get_health)); // public endpoint (no API key check)
ctx_http.get ("/v1/health", ex_wrapper(routes.get_health)); // public endpoint (no API key check)
ctx_http.get ("/metrics", ex_wrapper(routes.get_metrics));
@@ -5587,51 +5729,74 @@ int main(int argc, char ** argv) {
// Start the server
//
// setup clean up function, to be called before exit
auto clean_up = [&ctx_http, &ctx_server]() {
SRV_INF("%s: cleaning up before exit...\n", __func__);
ctx_http.stop();
ctx_server.queue_results.terminate();
llama_backend_free();
};
std::function<void()> clean_up;
// start the HTTP server before loading the model to be able to serve /health requests
if (!ctx_http.start()) {
clean_up();
LOG_ERR("%s: exiting due to HTTP server error\n", __func__);
return 1;
}
if (is_router_server) {
LOG_INF("%s: starting router server, no model will be loaded in this process\n", __func__);
ctx_http.is_ready.store(true);
// load the model
LOG_INF("%s: loading model\n", __func__);
clean_up = []() {
SRV_INF("%s: cleaning up before exit...\n", __func__);
llama_backend_free();
};
if (!ctx_server.load_model(params)) {
clean_up();
if (ctx_http.thread.joinable()) {
ctx_http.thread.join();
if (!ctx_http.start()) {
clean_up();
LOG_ERR("%s: exiting due to HTTP server error\n", __func__);
return 1;
}
LOG_ERR("%s: exiting due to model loading error\n", __func__);
return 1;
shutdown_handler = [&](int) {
ctx_http.stop();
};
} else {
// setup clean up function, to be called before exit
clean_up = [&ctx_http, &ctx_server]() {
SRV_INF("%s: cleaning up before exit...\n", __func__);
ctx_http.stop();
ctx_server.queue_results.terminate();
llama_backend_free();
};
// start the HTTP server before loading the model to be able to serve /health requests
if (!ctx_http.start()) {
clean_up();
LOG_ERR("%s: exiting due to HTTP server error\n", __func__);
return 1;
}
// load the model
LOG_INF("%s: loading model\n", __func__);
if (!ctx_server.load_model(params)) {
clean_up();
if (ctx_http.thread.joinable()) {
ctx_http.thread.join();
}
LOG_ERR("%s: exiting due to model loading error\n", __func__);
return 1;
}
ctx_server.init();
ctx_http.is_ready.store(true);
LOG_INF("%s: model loaded\n", __func__);
ctx_server.queue_tasks.on_new_task([&ctx_server](server_task && task) {
ctx_server.process_single_task(std::move(task));
});
ctx_server.queue_tasks.on_update_slots([&ctx_server]() {
ctx_server.update_slots();
});
shutdown_handler = [&](int) {
// this will unblock start_loop()
ctx_server.queue_tasks.terminate();
};
}
ctx_server.init();
ctx_http.is_ready.store(true);
LOG_INF("%s: model loaded\n", __func__);
ctx_server.queue_tasks.on_new_task([&ctx_server](server_task && task) {
ctx_server.process_single_task(std::move(task));
});
ctx_server.queue_tasks.on_update_slots([&ctx_server]() {
ctx_server.update_slots();
});
shutdown_handler = [&](int) {
// this will unblock start_loop()
ctx_server.queue_tasks.terminate();
};
#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__))
struct sigaction sigint_action;
sigint_action.sa_handler = signal_handler;
@@ -5646,16 +5811,32 @@ int main(int argc, char ** argv) {
SetConsoleCtrlHandler(reinterpret_cast<PHANDLER_ROUTINE>(console_ctrl_handler), true);
#endif
LOG_INF("%s: server is listening on %s\n", __func__, ctx_http.listening_address.c_str());
LOG_INF("%s: starting the main loop...\n", __func__);
// this call blocks the main thread until queue_tasks.terminate() is called
ctx_server.queue_tasks.start_loop();
if (is_router_server) {
LOG_INF("%s: router server is listening on %s\n", __func__, ctx_http.listening_address.c_str());
ctx_http.is_ready.store(true);
ctx_http.thread.join(); // keep the main thread alive
clean_up();
if (ctx_http.thread.joinable()) {
ctx_http.thread.join();
// when the HTTP server stops, clean up and exit
clean_up();
// TODO @ngxson : why the models are already unloaded without this line?
// routes.models->unload_all();
} else {
LOG_INF("%s: server is listening on %s\n", __func__, ctx_http.listening_address.c_str());
LOG_INF("%s: starting the main loop...\n", __func__);
// optionally, notify router server that this instance is ready
server_models::notify_router_server_ready(params.model_alias);
// this call blocks the main thread until queue_tasks.terminate() is called
ctx_server.queue_tasks.start_loop();
clean_up();
if (ctx_http.thread.joinable()) {
ctx_http.thread.join();
}
llama_memory_breakdown_print(ctx_server.ctx);
}
llama_memory_breakdown_print(ctx_server.ctx);
return 0;
}