From bfc135fee2f01136773466e9f92e1bd8dfaf4307 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Thu, 30 Apr 2026 11:55:13 +0200 Subject: [PATCH] server: support Vertex AI compatible API --- tools/server/server-http.cpp | 156 +++++++++++++++++++++++ tools/server/server-http.h | 11 +- tools/server/server.cpp | 4 + tools/server/tests/unit/test_vertexai.py | 59 +++++++++ 4 files changed, 227 insertions(+), 3 deletions(-) create mode 100644 tools/server/tests/unit/test_vertexai.py diff --git a/tools/server/server-http.cpp b/tools/server/server-http.cpp index 6f24f83ef3..509fc2ddae 100644 --- a/tools/server/server-http.cpp +++ b/tools/server/server-http.cpp @@ -4,7 +4,9 @@ #include +#include #include +#include #include #include @@ -420,6 +422,7 @@ static void process_handler_response(server_http_req_ptr && request, server_http } void server_http_context::get(const std::string & path, const server_http_context::handler_t & handler) const { + handlers.emplace(path, handler); pimpl->srv->Get(path_prefix + path, [handler](const httplib::Request & req, httplib::Response & res) { server_http_req_ptr request = std::make_unique(server_http_req{ get_params(req), @@ -436,6 +439,7 @@ void server_http_context::get(const std::string & path, const server_http_contex } void server_http_context::post(const std::string & path, const server_http_context::handler_t & handler) const { + handlers.emplace(path, handler); pimpl->srv->Post(path_prefix + path, [handler](const httplib::Request & req, httplib::Response & res) { std::string body = req.body; std::map files; @@ -481,3 +485,155 @@ void server_http_context::post(const std::string & path, const server_http_conte }); } +// +// Vertex AI Prediction protocol (AIP_PREDICT_ROUTE) +// https://cloud.google.com/vertex-ai/docs/predictions/custom-container-requirements +// + +static std::string get_vertexai_predict_route() { + const char * value = std::getenv("AIP_PREDICT_ROUTE"); + if (value == nullptr || value[0] == '\0') { + return "/predict"; + } + std::string route = value; + if (route[0] != '/') { + route.insert(route.begin(), '/'); + } + return route; +} + +// Derives the camelCase @requestFormat alias for a registered path. +// e.g. "/v1/chat/completions" -> "chatCompletions", "/apply-template" -> "applyTemplate" +static std::string path_to_vertexai_format(const std::string & path) { + std::string s = path; + if (s.size() > 3 && s[0] == '/' && s[1] == 'v' && s[2] == '1') { + s = s.substr(3); + } + if (!s.empty() && s[0] == '/') { + s = s.substr(1); + } + std::string result; + bool cap = false; + for (unsigned char c : s) { + if (c == ':') break; // stop before path parameters + if (c == '/' || c == '-' || c == '_') { + cap = true; + } else { + result += cap ? (char)std::toupper(c) : (char)c; + cap = false; + } + } + return result; +} + +static json parse_vertexai_predict_response(const server_http_res_ptr & res) { + if (res == nullptr) { + throw std::runtime_error("empty response from internal handler"); + } + if (res->is_stream()) { + throw std::invalid_argument("predict route does not support streaming responses"); + } + if (res->data.empty()) { + return nullptr; + } + try { + return json::parse(res->data); + } catch (...) { + return res->data; + } +} + +void server_http_context::register_vertexai() { + const std::string route = get_vertexai_predict_route(); + + if (handlers.count(route)) { + LOG_ERR("%s: AIP_PREDICT_ROUTE=%s conflicts with an existing llama-server route\n", __func__, route.c_str()); + exit(1); + } + + // camelCase alias -> canonical path (first registration wins on collision) + // e.g. "chatCompletions" -> "/v1/chat/completions" + std::unordered_map alias_to_path; + for (const auto & [path, _] : handlers) { + alias_to_path.emplace(path_to_vertexai_format(path), path); + } + + post(route, [this, alias_to_path = std::move(alias_to_path)](const server_http_req & req) -> server_http_res_ptr { + static const auto build_error = [](const std::string & message, error_type type = ERROR_TYPE_INVALID_REQUEST) -> json { + return json {{"error", format_error_response(message, type)}}; + }; + + json data = json::parse(req.body); + if (!data.is_object()) { + throw std::invalid_argument("request body must be a JSON object"); + } + if (!data.contains("instances") || !data.at("instances").is_array()) { + throw std::invalid_argument("request body must include an array field named instances"); + } + + const json & instances = data.at("instances"); + std::vector> futures; + futures.reserve(instances.size()); + + for (const auto & instance : instances) { + futures.push_back(std::async(std::launch::async, [this, &req, &alias_to_path, instance]() -> json { + if (!instance.is_object()) { + return build_error("each instance must be a JSON object"); + } + if (!instance.contains("@requestFormat") || !instance.at("@requestFormat").is_string()) { + return build_error("each instance must include a string @requestFormat"); + } + + try { + json payload = instance; + const std::string format = payload.at("@requestFormat").get(); + payload.erase("@requestFormat"); + + if (payload.contains("stream")) { + LOG_WRN("%s: ignoring client-provided stream field in instance, streaming is not supported in predict route\n", __func__); + payload["stream"] = false; + } + + // accept both camelCase aliases (e.g. "chatCompletions") and direct paths + std::string dispatch_path; + auto it_alias = alias_to_path.find(format); + if (it_alias != alias_to_path.end()) { + dispatch_path = it_alias->second; + } else if (handlers.count(format)) { + dispatch_path = format; + } else { + return build_error("no handler registered for @requestFormat: " + format); + } + + const server_http_req internal_req { + req.params, + req.headers, + path_prefix + dispatch_path, + req.query_string, + payload.dump(), + {}, + req.should_stop, + }; + + server_http_res_ptr internal_res = handlers.at(dispatch_path)(internal_req); + return parse_vertexai_predict_response(internal_res); + } catch (const std::invalid_argument & e) { + return build_error(e.what()); + } catch (const std::exception & e) { + return build_error(e.what(), ERROR_TYPE_SERVER); + } catch (...) { + return build_error("unknown error", ERROR_TYPE_SERVER); + } + })); + } + + json predictions = json::array(); + for (auto & future : futures) { + predictions.push_back(future.get()); + } + + auto res = std::make_unique(); + res->data = safe_json_to_str({{"predictions", predictions}}); + return res; + }); +} diff --git a/tools/server/server-http.h b/tools/server/server-http.h index d4d3b6e536..93f32c68d3 100644 --- a/tools/server/server-http.h +++ b/tools/server/server-http.h @@ -67,6 +67,10 @@ struct server_http_context { std::thread thread; // server thread std::atomic is_ready = false; + // note: the handler should never throw exceptions + using handler_t = std::function; + mutable std::unordered_map handlers; + std::string path_prefix; std::string hostname; int port; @@ -78,12 +82,13 @@ struct server_http_context { bool start(); void stop() const; - // note: the handler should never throw exceptions - using handler_t = std::function; - void get(const std::string & path, const handler_t & handler) const; void post(const std::string & path, const handler_t & handler) const; + // Register the Vertex AI Prediction protocol endpoint (AIP_PREDICT_ROUTE env var, or /predict) + // Must be called AFTER all other routes are registered + void register_vertexai(); + // for debugging std::string listening_address; }; diff --git a/tools/server/server.cpp b/tools/server/server.cpp index 6566949edf..90a9f3ebc3 100644 --- a/tools/server/server.cpp +++ b/tools/server/server.cpp @@ -204,6 +204,10 @@ int main(int argc, char ** argv) { // Save & load slots ctx_http.get ("/slots", ex_wrapper(routes.get_slots)); ctx_http.post("/slots/:id_slot", ex_wrapper(routes.post_slots)); + + // Vertex AI Prediction protocol endpoint (AIP_PREDICT_ROUTE, or /predict by default) + ctx_http.register_vertexai(); + // CORS proxy (EXPERIMENTAL, only used by the Web UI for MCP) if (params.webui_mcp_proxy) { SRV_WRN("%s", "-----------------\n"); diff --git a/tools/server/tests/unit/test_vertexai.py b/tools/server/tests/unit/test_vertexai.py new file mode 100644 index 0000000000..73093507ac --- /dev/null +++ b/tools/server/tests/unit/test_vertexai.py @@ -0,0 +1,59 @@ +import pytest +from utils import * + +server: ServerProcess + + +@pytest.fixture(autouse=True) +def create_server(): + global server + server = ServerPreset.tinyllama2() + + +def test_vertexai_predict_camel_case(): + global server + server.start() + res = server.make_request("POST", "/predict", data={ + "instances": [ + { + "@requestFormat": "chatCompletions", + "max_tokens": 8, + "messages": [ + {"role": "user", "content": "What is the meaning of life?"}, + ], + } + ], + }) + assert res.status_code == 200 + assert "predictions" in res.body + assert len(res.body["predictions"]) == 1 + prediction = res.body["predictions"][0] + assert "choices" in prediction + assert len(prediction["choices"]) == 1 + assert prediction["choices"][0]["message"]["role"] == "assistant" + assert len(prediction["choices"][0]["message"]["content"]) > 0 + + +def test_vertexai_predict_multiple_instances(): + global server + server.n_slots = 2 + server.start() + res = server.make_request("POST", "/predict", data={ + "instances": [ + { + "@requestFormat": "chatCompletions", + "max_tokens": 8, + "messages": [{"role": "user", "content": "Say hello"}], + }, + { + "@requestFormat": "chatCompletions", + "max_tokens": 8, + "messages": [{"role": "user", "content": "Say world"}], + }, + ], + }) + assert res.status_code == 200 + assert len(res.body["predictions"]) == 2 + for prediction in res.body["predictions"]: + assert "choices" in prediction + assert len(prediction["choices"][0]["message"]["content"]) > 0