mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2026-05-13 20:44:09 +00:00
server: support Vertex AI compatible API
This commit is contained in:
@@ -4,7 +4,9 @@
|
||||
|
||||
#include <cpp-httplib/httplib.h>
|
||||
|
||||
#include <cstdlib>
|
||||
#include <functional>
|
||||
#include <future>
|
||||
#include <string>
|
||||
#include <thread>
|
||||
|
||||
@@ -420,6 +422,7 @@ static void process_handler_response(server_http_req_ptr && request, server_http
|
||||
}
|
||||
|
||||
void server_http_context::get(const std::string & path, const server_http_context::handler_t & handler) const {
|
||||
handlers.emplace(path, handler);
|
||||
pimpl->srv->Get(path_prefix + path, [handler](const httplib::Request & req, httplib::Response & res) {
|
||||
server_http_req_ptr request = std::make_unique<server_http_req>(server_http_req{
|
||||
get_params(req),
|
||||
@@ -436,6 +439,7 @@ void server_http_context::get(const std::string & path, const server_http_contex
|
||||
}
|
||||
|
||||
void server_http_context::post(const std::string & path, const server_http_context::handler_t & handler) const {
|
||||
handlers.emplace(path, handler);
|
||||
pimpl->srv->Post(path_prefix + path, [handler](const httplib::Request & req, httplib::Response & res) {
|
||||
std::string body = req.body;
|
||||
std::map<std::string, uploaded_file> files;
|
||||
@@ -481,3 +485,155 @@ void server_http_context::post(const std::string & path, const server_http_conte
|
||||
});
|
||||
}
|
||||
|
||||
//
|
||||
// Vertex AI Prediction protocol (AIP_PREDICT_ROUTE)
|
||||
// https://cloud.google.com/vertex-ai/docs/predictions/custom-container-requirements
|
||||
//
|
||||
|
||||
// Resolves the Vertex AI prediction route from the AIP_PREDICT_ROUTE
// environment variable. Falls back to "/predict" when the variable is
// unset or empty; a missing leading slash is prepended so the result is
// always an absolute path.
static std::string get_vertexai_predict_route() {
    const char * env_route = std::getenv("AIP_PREDICT_ROUTE");
    if (env_route == nullptr || *env_route == '\0') {
        return "/predict";
    }
    const std::string raw(env_route);
    return raw.front() == '/' ? raw : "/" + raw;
}
|
||||
|
||||
// Derives the camelCase @requestFormat alias for a registered path.
// e.g. "/v1/chat/completions" -> "chatCompletions", "/apply-template" -> "applyTemplate"
// Path parameters (everything from the first ':') are dropped.
static std::string path_to_vertexai_format(const std::string & path) {
    std::string s = path;
    // strip a leading "/v1" only when it is a complete path segment, so a
    // route such as "/v1x" is not mangled into "x"
    if (s.size() > 3 && s.compare(0, 3, "/v1") == 0 && s[3] == '/') {
        s = s.substr(3);
    }
    if (!s.empty() && s[0] == '/') {
        s = s.substr(1);
    }
    std::string result;
    bool cap = false; // capitalize the next emitted character
    for (unsigned char c : s) {
        if (c == ':') break; // stop before path parameters
        if (c == '/' || c == '-' || c == '_') {
            // separators are dropped; the following character is uppercased
            cap = true;
        } else {
            result += cap ? (char)std::toupper(c) : (char)c;
            cap = false;
        }
    }
    return result;
}
|
||||
|
||||
// Converts an internal handler response into a JSON prediction value.
// Throws std::runtime_error when the handler produced no response at all,
// and std::invalid_argument for streaming responses, which the predict
// protocol cannot represent. An empty body maps to JSON null; a body that
// fails to parse as JSON is passed through verbatim as a JSON string.
static json parse_vertexai_predict_response(const server_http_res_ptr & res) {
    if (!res) {
        throw std::runtime_error("empty response from internal handler");
    }
    if (res->is_stream()) {
        throw std::invalid_argument("predict route does not support streaming responses");
    }
    const auto & body = res->data;
    if (body.empty()) {
        return nullptr;
    }
    try {
        return json::parse(body);
    } catch (...) {
        // not valid JSON — return the raw payload unchanged
        return body;
    }
}
|
||||
|
||||
// Registers the Vertex AI Prediction endpoint (route taken from
// AIP_PREDICT_ROUTE, default "/predict"). Each entry of the request's
// "instances" array is dispatched to an existing llama-server route chosen
// by the instance's "@requestFormat" field; instances run concurrently and
// the results are collected into a single {"predictions": [...]} response.
// Per-instance failures are reported in-band as error objects inside
// "predictions" rather than failing the whole batch.
void server_http_context::register_vertexai() {
    const std::string route = get_vertexai_predict_route();

    // refuse to shadow an already-registered route — that would silently
    // hijack an existing endpoint
    if (handlers.count(route)) {
        LOG_ERR("%s: AIP_PREDICT_ROUTE=%s conflicts with an existing llama-server route\n", __func__, route.c_str());
        exit(1);
    }

    // camelCase alias -> canonical path (first registration wins on collision)
    // e.g. "chatCompletions" -> "/v1/chat/completions"
    std::unordered_map<std::string, std::string> alias_to_path;
    for (const auto & [path, _] : handlers) {
        alias_to_path.emplace(path_to_vertexai_format(path), path);
    }

    // alias map is moved into the closure; it must outlive the server, and the
    // handler itself is stored by post()
    post(route, [this, alias_to_path = std::move(alias_to_path)](const server_http_req & req) -> server_http_res_ptr {
        // wraps a message into the standard error envelope used by the server
        static const auto build_error = [](const std::string & message, error_type type = ERROR_TYPE_INVALID_REQUEST) -> json {
            return json {{"error", format_error_response(message, type)}};
        };

        // NOTE(review): json::parse throws on malformed bodies — presumably the
        // surrounding dispatch turns that into an error response; confirm with caller
        json data = json::parse(req.body);
        if (!data.is_object()) {
            throw std::invalid_argument("request body must be a JSON object");
        }
        if (!data.contains("instances") || !data.at("instances").is_array()) {
            throw std::invalid_argument("request body must include an array field named instances");
        }

        const json & instances = data.at("instances");
        std::vector<std::future<json>> futures;
        futures.reserve(instances.size());

        // fan out: one async task per instance; `instance` is captured by copy
        // so the task is independent of the parsed request body's lifetime,
        // while `req` and `alias_to_path` are captured by reference — safe
        // because every future is joined before this handler returns
        for (const auto & instance : instances) {
            futures.push_back(std::async(std::launch::async, [this, &req, &alias_to_path, instance]() -> json {
                if (!instance.is_object()) {
                    return build_error("each instance must be a JSON object");
                }
                if (!instance.contains("@requestFormat") || !instance.at("@requestFormat").is_string()) {
                    return build_error("each instance must include a string @requestFormat");
                }

                try {
                    // copy the instance and strip the routing field so only the
                    // actual request payload is forwarded to the handler
                    json payload = instance;
                    const std::string format = payload.at("@requestFormat").get<std::string>();
                    payload.erase("@requestFormat");

                    // streaming cannot be represented in a predict response,
                    // so force it off instead of rejecting the instance
                    if (payload.contains("stream")) {
                        LOG_WRN("%s: ignoring client-provided stream field in instance, streaming is not supported in predict route\n", __func__);
                        payload["stream"] = false;
                    }

                    // accept both camelCase aliases (e.g. "chatCompletions") and direct paths
                    std::string dispatch_path;
                    auto it_alias = alias_to_path.find(format);
                    if (it_alias != alias_to_path.end()) {
                        dispatch_path = it_alias->second;
                    } else if (handlers.count(format)) {
                        dispatch_path = format;
                    } else {
                        return build_error("no handler registered for @requestFormat: " + format);
                    }

                    // synthesize an internal request that reuses the outer
                    // request's params/headers/query but carries the instance
                    // payload as its body and no uploaded files
                    // (aggregate field order assumed from server_http_req — confirm against its definition)
                    const server_http_req internal_req {
                        req.params,
                        req.headers,
                        path_prefix + dispatch_path,
                        req.query_string,
                        payload.dump(),
                        {},
                        req.should_stop,
                    };

                    server_http_res_ptr internal_res = handlers.at(dispatch_path)(internal_req);
                    return parse_vertexai_predict_response(internal_res);
                } catch (const std::invalid_argument & e) {
                    // client-side problem (bad payload, streaming requested)
                    return build_error(e.what());
                } catch (const std::exception & e) {
                    return build_error(e.what(), ERROR_TYPE_SERVER);
                } catch (...) {
                    return build_error("unknown error", ERROR_TYPE_SERVER);
                }
            }));
        }

        // join all tasks, preserving the order of the incoming instances
        json predictions = json::array();
        for (auto & future : futures) {
            predictions.push_back(future.get());
        }

        auto res = std::make_unique<server_http_res>();
        res->data = safe_json_to_str({{"predictions", predictions}});
        return res;
    });
}
|
||||
|
||||
Reference in New Issue
Block a user