Mirror of https://github.com/ggml-org/llama.cpp.git
llama-eval : support multiple evaluation endpoints with dynamic task distribution
- Add ServerConfig dataclass (url, threads, name)
- Accept comma-separated --server, --threads, --server-name CLI args
- Dynamic shared-queue task distribution across servers (fast servers do more work)
- One ThreadPoolExecutor per server, workers pull from shared Queue
- Track which server processed each task (server_name in results)
- Thread-safe EvalState with threading.Lock for concurrent mutations
- Server column in HTML report and console output
- Backward compatible: single server works as before

Assisted-by: llama.cpp:local pi
commit 43f14a0a46
parent d26b1ffcc9
committed by Georgi Gerganov
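Note: the new flags take comma-separated lists that are zipped positionally, so the single-server form behaves exactly as before. A hypothetical two-server invocation (host names, display names, and thread counts are illustrative only, assuming the script is run as llama-eval):

    llama-eval --server http://host-a:8033,http://host-b:8033 --threads 32,16 --server-name a,b

Each URL gets its own worker-thread count; because every worker pulls from one shared queue, the server with more threads (or faster responses) simply completes more tasks.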
@@ -7,17 +7,26 @@ import os
 import re
 import subprocess
 import sys
+import threading
 import time
 from abc import ABC, abstractmethod
 from concurrent.futures import ThreadPoolExecutor, as_completed
 from dataclasses import dataclass, asdict, field
 from pathlib import Path
+from queue import Queue
 from typing import Dict, List, Optional, Any, Tuple

 import requests
 from tqdm import tqdm
 import random
 from math import sqrt


+@dataclass
+class ServerConfig:
+    url: str
+    threads: int
+    name: str = ""
+
+
 def wilson_interval(correct: int, total: int, z: float = 1.96) -> Tuple[float, float]:
     """Wilson score confidence interval for a proportion."""
     if total == 0:
@@ -126,6 +135,7 @@ class TaskState:
     tps_gen: Optional[float] = None
     t_gen_ms: Optional[float] = None
     reasoning_content: Optional[str] = None
+    server_name: Optional[str] = None


 class EvalState:
@@ -146,6 +156,7 @@ class EvalState:
         self.correct = 0
         self.processed = 0
         self.total_time: float = 0.0
+        self._lock = threading.Lock()

     def load_dataset(self, seed: int = 1234):
         if self.dataset_type == "aime":
@@ -204,33 +215,37 @@ class EvalState:
         tokens: Optional[int] = None,
         tps_gen: Optional[float] = None,
         t_gen_ms: Optional[float] = None,
-        reasoning_content: Optional[str] = None
+        reasoning_content: Optional[str] = None,
+        server_name: Optional[str] = None
     ):
-        if "cases" not in self.task_states:
-            self.task_states["cases"] = {}
-
-        self.task_states["cases"][task_id] = {
-            "task_id": task_id,
-            "prompt": prompt,
-            "expected": expected,
-            "response": response,
-            "answer": answer,
-            "grader_log": grader_log,
-            "correct": correct,
-            "status": status,
-            "tokens": tokens,
-            "tps_gen": tps_gen,
-            "t_gen_ms": t_gen_ms,
-            "reasoning_content": reasoning_content
-        }
-
-        self.correct = sum(1 for c in self.task_states.get("cases", {}).values() if c.get("correct", False))
+        with self._lock:
+            if "cases" not in self.task_states:
+                self.task_states["cases"] = {}
+
+            self.task_states["cases"][task_id] = {
+                "task_id": task_id,
+                "prompt": prompt,
+                "expected": expected,
+                "response": response,
+                "answer": answer,
+                "grader_log": grader_log,
+                "correct": correct,
+                "status": status,
+                "tokens": tokens,
+                "tps_gen": tps_gen,
+                "t_gen_ms": t_gen_ms,
+                "reasoning_content": reasoning_content,
+                "server_name": server_name
+            }
+
+            self.correct = sum(1 for c in self.task_states.get("cases", {}).values() if c.get("correct", False))

     def print_progress(self, task_state: TaskState, total_tasks: int, n_correct: int = 0):
         display_answer = task_state.answer if task_state.answer else "N/A"
         display_tokens = str(task_state.tokens) if task_state.tokens is not None else "N/A"
         display_tps = f"{task_state.tps_gen:.1f}" if task_state.tps_gen is not None else "N/A"
         display_t_gen = f"{task_state.t_gen_ms/1000:.1f}" if task_state.t_gen_ms is not None else "N/A"
+        display_server = task_state.server_name if task_state.server_name else "N/A"
         success_ratio = n_correct / self.processed if self.processed > 0 else 0.0
         first_line = task_state.question_text.split('\n')[0]
         truncated_question = first_line[:43]
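Note on the lock above: add_result is now called concurrently from many worker threads, and the recomputed self.correct iterates the whole cases dict, so the insert and the recount have to happen atomically as a pair; in CPython, iterating a dict while another thread inserts into it can also raise "RuntimeError: dictionary changed size during iteration". A minimal standalone sketch of the guarded pattern (hypothetical names, not part of the patch):

    import threading

    cases: dict = {}
    correct = 0
    lock = threading.Lock()

    def add(task_id: str, ok: bool) -> None:
        global correct
        with lock:  # guard the insert and the recount together
            cases[task_id] = {"correct": ok}
            correct = sum(1 for c in cases.values() if c["correct"])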
@@ -238,7 +253,7 @@ class EvalState:
             truncated_question += "..."
         else:
             truncated_question = truncated_question.ljust(43) + "..."
-        print(f"{self.processed:3}/{total_tasks:3} {task_state.task_id:<20} {self.dataset_type.upper()} {truncated_question:<40} {task_state.expected:<10} {display_answer:<10} {display_tokens:<6} {display_tps:<6} {display_t_gen:<8} {'✓' if task_state.correct else '✗'} [{n_correct:3}/{self.processed:3}, {success_ratio:.3f}]")
+        print(f"{self.processed:3}/{total_tasks:3} {task_state.task_id:<20} {self.dataset_type.upper()} {truncated_question:<40} {task_state.expected:<10} {display_answer:<10} {display_tokens:<6} {display_tps:<6} {display_t_gen:<8} {'✓' if task_state.correct else '✗'} [{n_correct:3}/{self.processed:3}, {success_ratio:.3f}] {display_server}")

     def print_summary(self):
         if self.total == 0:
@@ -272,7 +287,8 @@ class EvalState:
                 "tokens": None,
                 "tps_gen": None,
                 "t_gen_ms": None,
-                "reasoning_content": None
+                "reasoning_content": None,
+                "server_name": None
             }

         ci_lower, ci_upper = self.accuracy_ci()
@@ -339,11 +355,13 @@ class EvalState:
             t_gen_ms = case.get("t_gen_ms")
             t_gen_str = f"{t_gen_ms/1000:.1f}" if t_gen_ms is not None else ""
             reasoning_content = case.get("reasoning_content", "") or ""
+            server_name = case.get("server_name", "") or ""

             escaped_response = self._escape_html(response)
             escaped_prompt = self._escape_html(prompt)
             escaped_reasoning = self._escape_html(reasoning_content)
             grader_log_str = self._escape_html(json.dumps(grader_log, indent=2))
+            escaped_server = self._escape_html(server_name)

             rows.append(f"""<tr class="task-row" onclick="toggleDetails('{task_id}')">
                 <td>{task_id}</td>
@@ -353,9 +371,10 @@ class EvalState:
                 <td>{tokens_str}</td>
                 <td>{tps_str}</td>
                 <td>{t_gen_str}</td>
+                <td>{escaped_server}</td>
             </tr>
             <tr id="details-{task_id}" class="details-row">
-                <td colspan="7">
+                <td colspan="8">
                     <div class="details-content">
                         <h4>Prompt</h4>
                         <pre>{escaped_prompt}</pre>
@@ -425,6 +444,7 @@ class EvalState:
                 <th>Tokens</th>
                 <th>T/s</th>
                 <th>Gen s</th>
+                <th>Server</th>
             </tr>
         </thead>
         <tbody>
@@ -527,6 +547,7 @@ class EvalState:
             tps_str = f"{tps_gen:.1f}" if tps_gen is not None else "N/A"
             t_gen_ms = case.get("t_gen_ms")
             t_gen_str = f"{t_gen_ms/1000:.1f}" if t_gen_ms is not None else "N/A"
+            server_name = case.get("server_name", "") or ""
             is_correct = case.get("correct", False) if status == "ok" else False
             symbol = "✓ " if is_correct else ("✗ " if status == "ok" else "")
             first_line = question.split('\n')[0]
@@ -535,7 +556,7 @@ class EvalState:
                 question_trunc += "..."
             else:
                 question_trunc = question_trunc.ljust(43) + "..."
-            print(f"  {task_id:<20} {self.dataset_type.upper()} {question_trunc:<40} {expected:<10} {answer:<10} {tokens_str:<6} {tps_str:<6} {t_gen_str:<8} {symbol}{status}")
+            print(f"  {task_id:<20} {self.dataset_type.upper()} {question_trunc:<40} {expected:<10} {answer:<10} {tokens_str:<6} {tps_str:<6} {t_gen_str:<8} {symbol}{status} {server_name}")
         print()

     def print_existing_summary(self):
@@ -947,20 +968,20 @@ Please provide only the extracted answer, nothing else. If there is no clear ans
 class Processor:
     def __init__(
         self,
-        server_url: str,
+        server_configs: List[ServerConfig],
         grader: Grader,
         model_name: Optional[str] = None,
         threads: int = 32,
         n_predict: int = -1
     ):
-        self.server_url = server_url
+        self.server_configs = server_configs
         self.grader = grader
         self.model_name = model_name
         self.threads = threads
         self.n_predict = n_predict

-    def _make_request(self, eval_state: EvalState, prompt: str) -> Tuple[Dict[str, Any], int, Optional[float], Optional[float], str]:
-        url = f"{self.server_url}/v1/chat/completions"
+    def _make_request(
+        self, server_config: ServerConfig, eval_state: EvalState, prompt: str
+    ) -> Tuple[Dict[str, Any], int, Optional[float], Optional[float], str]:
+        url = f"{server_config.url}/v1/chat/completions"
         headers = {"Content-Type": "application/json"}
         data = {
             "model": self.model_name if self.model_name else "llama",
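For context, the URL each worker now builds targets llama-server's OpenAI-compatible chat endpoint. A minimal request along the lines the surrounding code constructs (host, model name, and message are hypothetical):

    import requests

    resp = requests.post(
        "http://localhost:8033/v1/chat/completions",
        headers={"Content-Type": "application/json"},
        json={
            "model": "llama",
            "messages": [{"role": "user", "content": "What is 2+2?"}],
        },
    )
    print(resp.json()["choices"][0]["message"]["content"])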
@@ -986,18 +1007,21 @@ class Processor:
         finish_reason = result.get("choices", [{}])[0].get("finish_reason", "stop")
         return result, tokens, tps_gen, t_gen_ms, finish_reason

-    def _process_single_case(self, eval_state: EvalState, i: int, task_id: str) -> TaskState:
+    def _process_single_case(
+        self, server_config: ServerConfig, eval_state: EvalState, i: int, task_id: str
+    ) -> TaskState:
         question_text, prompt, expected = eval_state.get_case(i)

         task_state = TaskState(
             task_id=task_id,
             prompt=prompt,
             expected=expected,
-            question_text=question_text
+            question_text=question_text,
+            server_name=server_config.name
         )

         try:
-            response, tokens, tps_gen, t_gen_ms, finish_reason = self._make_request(eval_state, prompt)
+            response, tokens, tps_gen, t_gen_ms, finish_reason = self._make_request(server_config, eval_state, prompt)
             result = response["choices"][0]["message"]["content"]
             reasoning_content = response["choices"][0].get("message", {}).get("reasoning_content")
             task_state.response = result
@@ -1008,7 +1032,11 @@ class Processor:

         if finish_reason != "stop":
             task_state.status = f"error: finish_reason={finish_reason}"
-            eval_state.add_result(task_id, prompt, expected, result, None, {"finish_reason": finish_reason}, False, task_state.status, tokens, tps_gen, t_gen_ms, reasoning_content)
+            eval_state.add_result(
+                task_id, prompt, expected, result, None,
+                {"finish_reason": finish_reason}, False, task_state.status,
+                tokens, tps_gen, t_gen_ms, reasoning_content, server_config.name
+            )
             eval_state.dump()
             return task_state
@@ -1027,7 +1055,11 @@ class Processor:
         task_state.grader_log = grader_log
         task_state.status = "ok"

-        eval_state.add_result(task_id, prompt, expected, result, answer, grader_log, is_correct, "ok", tokens, tps_gen, t_gen_ms, reasoning_content)
+        eval_state.add_result(
+            task_id, prompt, expected, result, answer,
+            grader_log, is_correct, "ok",
+            tokens, tps_gen, t_gen_ms, reasoning_content, server_config.name
+        )

         eval_state.dump()
@@ -1036,47 +1068,106 @@ class Processor:

         return task_state

+    @staticmethod
+    def _worker(
+        server_config: ServerConfig,
+        processor: "Processor",
+        eval_state: EvalState,
+        task_queue: Queue,
+        results_queue: Queue,
+    ):
+        """Worker that pulls tasks from a shared queue and sends them to its server."""
+        while True:
+            task = task_queue.get()
+            if task is None:  # sentinel
+                task_queue.task_done()
+                break
+            try:
+                i, task_id = task
+                result = processor._process_single_case(server_config, eval_state, i, task_id)
+                results_queue.put(result)
+            finally:
+                task_queue.task_done()
+
     def evaluate(self, eval_state: EvalState, verbose: bool = False, resume: bool = False):
         total_tasks = len(eval_state.tasks)
         eval_state.total = len(eval_state.all_tasks) if eval_state.all_tasks else total_tasks
         eval_state.processed = 0
         start_time = time.time()

+        # Print server info
+        server_lines = [
+            f"  {i+1}. {sc.name} — {sc.url} ({sc.threads} threads)"
+            for i, sc in enumerate(self.server_configs)
+        ]
         print(f"\nProcessing {len(eval_state.tasks)} {eval_state.dataset_type.upper()} tasks ...")
-        print(f"Server: {self.server_url} (model: {self.model_name})")
+        print(f"Servers ({len(self.server_configs)}):")
+        for line in server_lines:
+            print(line)
+        print(f"Model: {self.model_name}")
         print(f"Grader: {self.grader.grader_type}")
-        print(f"Threads: {self.threads}")
         print(f"Sampling: temp={eval_state.sampling_config.get('temperature', 'skip')}, top-k={eval_state.sampling_config.get('top_k', 'skip')}, top-p={eval_state.sampling_config.get('top_p', 'skip')}, min-p={eval_state.sampling_config.get('min_p', 'skip')}")
         print()

+        # Shared task queue: all workers compete for tasks
+        task_queue: Queue = Queue()
+        for i, task_id in eval_state.tasks:
+            task_queue.put((i, task_id))
+
+        # Results queue: workers push completed TaskStates here
+        results_queue: Queue = Queue()
+
+        # Total worker threads across all servers
+        total_threads = sum(sc.threads for sc in self.server_configs)
+
+        # Add one sentinel per worker so every worker exits cleanly
+        for _ in range(total_threads):
+            task_queue.put(None)
+
+        # Launch workers: one ThreadPoolExecutor per server
+        executors: List[ThreadPoolExecutor] = []
+        worker_futures: List[Any] = []
+        for server_config in self.server_configs:
+            executor = ThreadPoolExecutor(max_workers=server_config.threads)
+            executors.append(executor)
+            for _ in range(server_config.threads):
+                future = executor.submit(
+                    self._worker, server_config, self, eval_state,
+                    task_queue, results_queue
+                )
+                worker_futures.append(future)
+
+        # Drain results as they complete
         n_correct = 0
+        session_time = 0.0
+        completed_count = 0

-        with ThreadPoolExecutor(max_workers=self.threads) as executor:
-            futures = {
-                executor.submit(self._process_single_case, eval_state, i, task_id): (i, task_id)
-                for i, task_id in eval_state.tasks
-            }
-
-            session_time = 0.0
-            for future in as_completed(futures):
-                task_state = future.result()
-                eval_state.processed += 1
-                if task_state.correct:
-                    n_correct += 1
-                elapsed = time.time() - start_time
-                eval_state.total_time += elapsed
-                session_time += elapsed
-                start_time = time.time()
-                eval_state.print_progress(task_state, total_tasks, n_correct)
-                if verbose:
-                    print(f"\nCase {eval_state.processed}: {task_state.correct}")
-                    print(f"  Expected: {task_state.expected}")
-                    if task_state.response:
-                        print(f"  Response: {task_state.response}")
-                    if task_state.answer:
-                        print(f"  Answer: {task_state.answer}")
-                    print(f"  Status: {task_state.status}")
+        while completed_count < total_tasks:
+            task_state = results_queue.get()
+            eval_state.processed += 1
+            completed_count += 1
+            if task_state.correct:
+                n_correct += 1
+            elapsed = time.time() - start_time
+            eval_state.total_time += elapsed
+            session_time += elapsed
+            start_time = time.time()
+            eval_state.print_progress(task_state, total_tasks, n_correct)
+
+            if verbose:
+                print(f"\nCase {eval_state.processed}: {task_state.correct}")
+                print(f"  Expected: {task_state.expected}")
+                if task_state.response:
+                    print(f"  Response: {task_state.response}")
+                if task_state.answer:
+                    print(f"  Answer: {task_state.answer}")
+                print(f"  Status: {task_state.status}")
+
+        # Wait for all workers to finish and shut down executors
+        for future in worker_futures:
+            future.result()
+        for executor in executors:
+            executor.shutdown(wait=True)

         print(f"\nSession time: {session_time:.1f}s | Total accumulated time: {eval_state.total_time:.1f}s")
         eval_state.print_summary()
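The scheduling above has no per-server task assignment: every worker drains one shared queue, and one None sentinel per worker guarantees each thread wakes exactly once for shutdown. Completion is detected by counting results (completed_count), not by Queue.join(), so the task_done() calls in _worker are bookkeeping rather than load-bearing. A standalone sketch of the pattern (toy task and names, standard library only):

    import threading
    from queue import Queue

    def worker(name: str, tasks: Queue, results: Queue) -> None:
        while True:
            task = tasks.get()
            if task is None:  # sentinel: exactly one per worker
                break
            results.put((name, task * task))  # stand-in for a server request

    tasks: Queue = Queue()
    results: Queue = Queue()
    for t in range(10):
        tasks.put(t)

    # two "servers" with unequal thread counts; the larger or faster pool
    # naturally pulls more tasks from the shared FIFO queue
    workers = [threading.Thread(target=worker, args=(n, tasks, results))
               for n in ("a", "a", "b")]
    for _ in workers:
        tasks.put(None)  # one sentinel per worker, queued behind all tasks
    for w in workers:
        w.start()
    for w in workers:
        w.join()
    while not results.empty():
        print(results.get())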
@@ -1090,7 +1181,13 @@ def main():
     parser.add_argument(
         "--server",
         type=str,
         default="http://localhost:8033",
-        help="llama-server URL (default: http://localhost:8033)"
+        help="Comma-separated llama-server URLs (default: http://localhost:8033)"
+    )
+    parser.add_argument(
+        "--server-name",
+        type=str,
+        default="",
+        help="Comma-separated display names for servers (default: use URLs)"
     )
     parser.add_argument(
         "--dataset",
@@ -1143,9 +1240,9 @@ def main():
     )
     parser.add_argument(
         "--threads",
-        type=int,
-        default=32,
-        help="Number of threads for parallel requests (default: 32)"
+        type=str,
+        default="32",
+        help="Comma-separated thread counts per server (default: 32)"
     )
     parser.add_argument(
         "--model",
@@ -1197,6 +1294,28 @@ def main():

     args = parser.parse_args()

+    # Parse server URLs and thread counts
+    server_urls = [u.strip() for u in args.server.split(",") if u.strip()]
+    thread_counts = [int(t.strip()) for t in args.threads.split(",") if t.strip()]
+
+    if len(server_urls) != len(thread_counts):
+        print(f"Error: --server ({len(server_urls)} URLs) and --threads ({len(thread_counts)} values) must have the same count")
+        sys.exit(1)
+
+    # Parse server names (optional, defaults to URLs)
+    if args.server_name:
+        server_names = [n.strip() for n in args.server_name.split(",") if n.strip()]
+        if len(server_names) != len(server_urls):
+            print(f"Error: --server-name ({len(server_names)} names) and --server ({len(server_urls)} URLs) must have the same count")
+            sys.exit(1)
+    else:
+        server_names = server_urls  # fallback to URLs
+
+    server_configs = [
+        ServerConfig(url=url, threads=threads, name=name)
+        for url, threads, name in zip(server_urls, thread_counts, server_names)
+    ]
+
     if args.dataset == "gpqa" and args.grader_type != "llm":
         print("Error: GPQA dataset requires --grader-type llm")
         parser.print_help()
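To make the zipping concrete: a hypothetical --server http://a:8033,http://b:8033 --threads 32,16 with no --server-name parses to

    server_configs = [
        ServerConfig(url="http://a:8033", threads=32, name="http://a:8033"),
        ServerConfig(url="http://b:8033", threads=16, name="http://b:8033"),
    ]

with the display names falling back to the URLs; the first entry also serves as the default grader server below.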
@@ -1226,7 +1345,7 @@ def main():
         eval_state.tasks = pending_tasks
         eval_state.task_states["cases"] = existing_cases

-        grader_server_url = args.grader_server if args.grader_server else args.server
+        grader_server_url = args.grader_server if args.grader_server else server_configs[0].url
         grader_model_name = args.grader_model if args.grader_model else args.model
         grader = Grader(
             grader_type=args.grader_type,
@@ -1241,7 +1360,7 @@ def main():
             print("Error: No existing eval state found to resume")
             sys.exit(1)

-        grader_server_url = args.grader_server if args.grader_server else args.server
+        grader_server_url = args.grader_server if args.grader_server else server_configs[0].url
         grader_model_name = args.grader_model if args.grader_model else args.model

         grader = Grader(
@@ -1278,10 +1397,9 @@ def main():
     eval_state.print_all_tasks()

     processor = Processor(
-        server_url=args.server,
+        server_configs=server_configs,
         grader=grader,
         model_name=args.model,
-        threads=args.threads,
         n_predict=args.n_predict
     )