server: support Vertex AI compatible API

Xuan Son Nguyen
2026-04-30 11:55:13 +02:00
parent 19821178be
commit bfc135fee2
4 changed files with 227 additions and 3 deletions
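
A minimal sketch of a client call against the new route, assuming a server listening on http://localhost:8080 with AIP_PREDICT_ROUTE unset (so the route defaults to /predict); the chat message and the printed output are illustrative, not part of this commit:

// hypothetical standalone client, built against the cpp-httplib header
// that llama.cpp already vendors
#include <cpp-httplib/httplib.h>
#include <cstdio>

int main() {
    httplib::Client cli("http://localhost:8080");
    // each instance names its target endpoint via @requestFormat, either as a
    // camelCase alias ("chatCompletions") or a direct path ("/v1/chat/completions")
    const char * body = R"({
        "instances": [
            {
                "@requestFormat": "chatCompletions",
                "messages": [{"role": "user", "content": "Hello"}]
            }
        ]
    })";
    auto res = cli.Post("/predict", body, "application/json");
    if (res && res->status == 200) {
        printf("%s\n", res->body.c_str()); // {"predictions":[ ... ]}
    }
    return 0;
}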

@@ -4,7 +4,9 @@
#include <cpp-httplib/httplib.h>
#include <cstdlib>
#include <functional>
#include <future>
#include <string>
#include <thread>
@@ -420,6 +422,7 @@ static void process_handler_response(server_http_req_ptr && request, server_http
}
void server_http_context::get(const std::string & path, const server_http_context::handler_t & handler) const {
    handlers.emplace(path, handler);
    pimpl->srv->Get(path_prefix + path, [handler](const httplib::Request & req, httplib::Response & res) {
        server_http_req_ptr request = std::make_unique<server_http_req>(server_http_req{
            get_params(req),
@@ -436,6 +439,7 @@ void server_http_context::get(const std::string & path, const server_http_contex
}
void server_http_context::post(const std::string & path, const server_http_context::handler_t & handler) const {
    handlers.emplace(path, handler);
    pimpl->srv->Post(path_prefix + path, [handler](const httplib::Request & req, httplib::Response & res) {
        std::string body = req.body;
        std::map<std::string, uploaded_file> files;
@@ -481,3 +485,155 @@ void server_http_context::post(const std::string & path, const server_http_conte
    });
}
//
// Vertex AI Prediction protocol (AIP_PREDICT_ROUTE)
// https://cloud.google.com/vertex-ai/docs/predictions/custom-container-requirements
//
static std::string get_vertexai_predict_route() {
    const char * value = std::getenv("AIP_PREDICT_ROUTE");
    if (value == nullptr || value[0] == '\0') {
        return "/predict";
    }
    std::string route = value;
    if (route[0] != '/') {
        route.insert(route.begin(), '/');
    }
    return route;
}
// Derives the camelCase @requestFormat alias for a registered path.
// e.g. "/v1/chat/completions" -> "chatCompletions", "/apply-template" -> "applyTemplate"
static std::string path_to_vertexai_format(const std::string & path) {
    std::string s = path;
    if (s.size() > 3 && s[0] == '/' && s[1] == 'v' && s[2] == '1') {
        s = s.substr(3);
    }
    if (!s.empty() && s[0] == '/') {
        s = s.substr(1);
    }
    std::string result;
    bool cap = false;
    for (unsigned char c : s) {
        if (c == ':') break; // stop before path parameters
        if (c == '/' || c == '-' || c == '_') {
            cap = true;
        } else {
            result += cap ? (char)std::toupper(c) : (char)c;
            cap = false;
        }
    }
    return result;
}
static json parse_vertexai_predict_response(const server_http_res_ptr & res) {
    if (res == nullptr) {
        throw std::runtime_error("empty response from internal handler");
    }
    if (res->is_stream()) {
        throw std::invalid_argument("predict route does not support streaming responses");
    }
    if (res->data.empty()) {
        return nullptr;
    }
    try {
        return json::parse(res->data);
    } catch (...) {
        return res->data; // not JSON: pass the raw body through as a string
    }
}
void server_http_context::register_vertexai() {
    const std::string route = get_vertexai_predict_route();
    if (handlers.count(route)) {
        LOG_ERR("%s: AIP_PREDICT_ROUTE=%s conflicts with an existing llama-server route\n", __func__, route.c_str());
        exit(1);
    }
    // camelCase alias -> canonical path (first registration wins on collision)
    // e.g. "chatCompletions" -> "/v1/chat/completions"
    std::unordered_map<std::string, std::string> alias_to_path;
    for (const auto & [path, _] : handlers) {
        alias_to_path.emplace(path_to_vertexai_format(path), path);
    }
    post(route, [this, alias_to_path = std::move(alias_to_path)](const server_http_req & req) -> server_http_res_ptr {
        static const auto build_error = [](const std::string & message, error_type type = ERROR_TYPE_INVALID_REQUEST) -> json {
            return json {{"error", format_error_response(message, type)}};
        };
        json data = json::parse(req.body);
        if (!data.is_object()) {
            throw std::invalid_argument("request body must be a JSON object");
        }
        if (!data.contains("instances") || !data.at("instances").is_array()) {
            throw std::invalid_argument("request body must include an array field named \"instances\"");
        }
        const json & instances = data.at("instances");
        // dispatch every instance concurrently; each future resolves to one prediction
        std::vector<std::future<json>> futures;
        futures.reserve(instances.size());
        for (const auto & instance : instances) {
            futures.push_back(std::async(std::launch::async, [this, &req, &alias_to_path, instance]() -> json {
                if (!instance.is_object()) {
                    return build_error("each instance must be a JSON object");
                }
                if (!instance.contains("@requestFormat") || !instance.at("@requestFormat").is_string()) {
                    return build_error("each instance must include a string field named \"@requestFormat\"");
                }
                try {
                    json payload = instance;
                    const std::string format = payload.at("@requestFormat").get<std::string>();
                    payload.erase("@requestFormat");
                    if (payload.contains("stream")) {
                        LOG_WRN("%s: ignoring client-provided stream field in instance, streaming is not supported on the predict route\n", __func__);
                        payload["stream"] = false;
                    }
                    // accept both camelCase aliases (e.g. "chatCompletions") and direct paths
                    std::string dispatch_path;
                    auto it_alias = alias_to_path.find(format);
                    if (it_alias != alias_to_path.end()) {
                        dispatch_path = it_alias->second;
                    } else if (handlers.count(format)) {
                        dispatch_path = format;
                    } else {
                        return build_error("no handler registered for @requestFormat: " + format);
                    }
                    // forward the unwrapped payload to the internal handler as a plain request
                    const server_http_req internal_req {
                        req.params,
                        req.headers,
                        path_prefix + dispatch_path,
                        req.query_string,
                        payload.dump(),
                        {}, // no uploaded files
                        req.should_stop,
                    };
                    server_http_res_ptr internal_res = handlers.at(dispatch_path)(internal_req);
                    return parse_vertexai_predict_response(internal_res);
                } catch (const std::invalid_argument & e) {
                    return build_error(e.what());
                } catch (const std::exception & e) {
                    return build_error(e.what(), ERROR_TYPE_SERVER);
                } catch (...) {
                    return build_error("unknown error", ERROR_TYPE_SERVER);
                }
            }));
        }
        // per-instance errors are reported inline, so one bad instance does not fail the batch
        json predictions = json::array();
        for (auto & future : futures) {
            predictions.push_back(future.get());
        }
        auto res = std::make_unique<server_http_res>();
        res->data = safe_json_to_str({{"predictions", predictions}});
        return res;
    });
}
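
For reference, a standalone sketch of the alias derivation above; it duplicates path_to_vertexai_format outside the server translation unit purely so the expected mappings can be checked in isolation (the example paths are existing llama-server routes):

#include <cassert>
#include <cctype>
#include <string>

static std::string path_to_vertexai_format(const std::string & path) {
    std::string s = path;
    if (s.size() > 3 && s[0] == '/' && s[1] == 'v' && s[2] == '1') {
        s = s.substr(3); // drop the "/v1" prefix
    }
    if (!s.empty() && s[0] == '/') {
        s = s.substr(1); // drop the leading slash
    }
    std::string result;
    bool cap = false;
    for (unsigned char c : s) {
        if (c == ':') break; // stop before path parameters
        if (c == '/' || c == '-' || c == '_') {
            cap = true; // a separator capitalizes the next character
        } else {
            result += cap ? (char)std::toupper(c) : (char)c;
            cap = false;
        }
    }
    return result;
}

int main() {
    assert(path_to_vertexai_format("/v1/chat/completions") == "chatCompletions");
    assert(path_to_vertexai_format("/apply-template") == "applyTemplate");
    assert(path_to_vertexai_format("/v1/embeddings") == "embeddings");
    return 0;
}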