diff --git a/examples/llama-eval/llama-eval.py b/examples/llama-eval/llama-eval.py
index 9c9b337933..07f6e9f573 100755
--- a/examples/llama-eval/llama-eval.py
+++ b/examples/llama-eval/llama-eval.py
@@ -7,17 +7,26 @@ import os
import re
import subprocess
import sys
+import threading
import time
from abc import ABC, abstractmethod
from concurrent.futures import ThreadPoolExecutor, as_completed
from dataclasses import dataclass, asdict, field
from pathlib import Path
+from queue import Queue
from typing import Dict, List, Optional, Any, Tuple
import requests
from tqdm import tqdm
import random
from math import sqrt
+
+@dataclass
+class ServerConfig:
+ url: str
+ threads: int
+ name: str = ""
+
def wilson_interval(correct: int, total: int, z: float = 1.96) -> Tuple[float, float]:
"""Wilson score confidence interval for a proportion."""
if total == 0:
@@ -126,6 +135,7 @@ class TaskState:
tps_gen: Optional[float] = None
t_gen_ms: Optional[float] = None
reasoning_content: Optional[str] = None
+ server_name: Optional[str] = None
class EvalState:
@@ -146,6 +156,7 @@ class EvalState:
self.correct = 0
self.processed = 0
self.total_time: float = 0.0
+ self._lock = threading.Lock()
def load_dataset(self, seed: int = 1234):
if self.dataset_type == "aime":
@@ -204,33 +215,37 @@ class EvalState:
tokens: Optional[int] = None,
tps_gen: Optional[float] = None,
t_gen_ms: Optional[float] = None,
- reasoning_content: Optional[str] = None
+ reasoning_content: Optional[str] = None,
+ server_name: Optional[str] = None
):
- if "cases" not in self.task_states:
- self.task_states["cases"] = {}
+ with self._lock:
+ if "cases" not in self.task_states:
+ self.task_states["cases"] = {}
- self.task_states["cases"][task_id] = {
- "task_id": task_id,
- "prompt": prompt,
- "expected": expected,
- "response": response,
- "answer": answer,
- "grader_log": grader_log,
- "correct": correct,
- "status": status,
- "tokens": tokens,
- "tps_gen": tps_gen,
- "t_gen_ms": t_gen_ms,
- "reasoning_content": reasoning_content
- }
+ self.task_states["cases"][task_id] = {
+ "task_id": task_id,
+ "prompt": prompt,
+ "expected": expected,
+ "response": response,
+ "answer": answer,
+ "grader_log": grader_log,
+ "correct": correct,
+ "status": status,
+ "tokens": tokens,
+ "tps_gen": tps_gen,
+ "t_gen_ms": t_gen_ms,
+ "reasoning_content": reasoning_content,
+ "server_name": server_name
+ }
- self.correct = sum(1 for c in self.task_states.get("cases", {}).values() if c.get("correct", False))
+ self.correct = sum(1 for c in self.task_states.get("cases", {}).values() if c.get("correct", False))
def print_progress(self, task_state: TaskState, total_tasks: int, n_correct: int = 0):
display_answer = task_state.answer if task_state.answer else "N/A"
display_tokens = str(task_state.tokens) if task_state.tokens is not None else "N/A"
display_tps = f"{task_state.tps_gen:.1f}" if task_state.tps_gen is not None else "N/A"
display_t_gen = f"{task_state.t_gen_ms/1000:.1f}" if task_state.t_gen_ms is not None else "N/A"
+ display_server = task_state.server_name if task_state.server_name else "N/A"
success_ratio = n_correct / self.processed if self.processed > 0 else 0.0
first_line = task_state.question_text.split('\n')[0]
truncated_question = first_line[:43]
@@ -238,7 +253,7 @@ class EvalState:
truncated_question += "..."
else:
truncated_question = truncated_question.ljust(43) + "..."
- print(f"{self.processed:3}/{total_tasks:3} {task_state.task_id:<20} {self.dataset_type.upper()} {truncated_question:<40} {task_state.expected:<10} {display_answer:<10} {display_tokens:<6} {display_tps:<6} {display_t_gen:<8} {'✓' if task_state.correct else '✗'} [{n_correct:3}/{self.processed:3}, {success_ratio:.3f}]")
+ print(f"{self.processed:3}/{total_tasks:3} {task_state.task_id:<20} {self.dataset_type.upper()} {truncated_question:<40} {task_state.expected:<10} {display_answer:<10} {display_tokens:<6} {display_tps:<6} {display_t_gen:<8} {'✓' if task_state.correct else '✗'} [{n_correct:3}/{self.processed:3}, {success_ratio:.3f}] {display_server}")
def print_summary(self):
if self.total == 0:
@@ -272,7 +287,8 @@ class EvalState:
"tokens": None,
"tps_gen": None,
"t_gen_ms": None,
- "reasoning_content": None
+ "reasoning_content": None,
+ "server_name": None
}
ci_lower, ci_upper = self.accuracy_ci()
@@ -339,11 +355,13 @@ class EvalState:
t_gen_ms = case.get("t_gen_ms")
t_gen_str = f"{t_gen_ms/1000:.1f}" if t_gen_ms is not None else ""
reasoning_content = case.get("reasoning_content", "") or ""
+ server_name = case.get("server_name", "") or ""
escaped_response = self._escape_html(response)
escaped_prompt = self._escape_html(prompt)
escaped_reasoning = self._escape_html(reasoning_content)
grader_log_str = self._escape_html(json.dumps(grader_log, indent=2))
+ escaped_server = self._escape_html(server_name)
rows.append(f"""
| {task_id} |
@@ -353,9 +371,10 @@ class EvalState:
{tokens_str} |
{tps_str} |
{t_gen_str} |
+ {escaped_server} |
- |
+ |
Prompt
{escaped_prompt}
@@ -425,6 +444,7 @@ class EvalState:
Tokens |
T/s |
Gen s |
+ Server |
|
@@ -527,6 +547,7 @@ class EvalState:
tps_str = f"{tps_gen:.1f}" if tps_gen is not None else "N/A"
t_gen_ms = case.get("t_gen_ms")
t_gen_str = f"{t_gen_ms/1000:.1f}" if t_gen_ms is not None else "N/A"
+ server_name = case.get("server_name", "") or ""
is_correct = case.get("correct", False) if status == "ok" else False
symbol = "✓ " if is_correct else ("✗ " if status == "ok" else "")
first_line = question.split('\n')[0]
@@ -535,7 +556,7 @@ class EvalState:
question_trunc += "..."
else:
question_trunc = question_trunc.ljust(43) + "..."
- print(f" {task_id:<20} {self.dataset_type.upper()} {question_trunc:<40} {expected:<10} {answer:<10} {tokens_str:<6} {tps_str:<6} {t_gen_str:<8} {symbol}{status}")
+ print(f" {task_id:<20} {self.dataset_type.upper()} {question_trunc:<40} {expected:<10} {answer:<10} {tokens_str:<6} {tps_str:<6} {t_gen_str:<8} {symbol}{status} {server_name}")
print()
def print_existing_summary(self):
@@ -947,20 +968,20 @@ Please provide only the extracted answer, nothing else. If there is no clear ans
class Processor:
def __init__(
self,
- server_url: str,
+ server_configs: List[ServerConfig],
grader: Grader,
model_name: Optional[str] = None,
- threads: int = 32,
n_predict: int = -1
):
- self.server_url = server_url
+ self.server_configs = server_configs
self.grader = grader
self.model_name = model_name
- self.threads = threads
self.n_predict = n_predict
- def _make_request(self, eval_state: EvalState, prompt: str) -> Tuple[Dict[str, Any], int, Optional[float], Optional[float], str]:
- url = f"{self.server_url}/v1/chat/completions"
+ def _make_request(
+ self, server_config: ServerConfig, eval_state: EvalState, prompt: str
+ ) -> Tuple[Dict[str, Any], int, Optional[float], Optional[float], str]:
+ url = f"{server_config.url}/v1/chat/completions"
headers = {"Content-Type": "application/json"}
data = {
"model": self.model_name if self.model_name else "llama",
@@ -986,18 +1007,21 @@ class Processor:
finish_reason = result.get("choices", [{}])[0].get("finish_reason", "stop")
return result, tokens, tps_gen, t_gen_ms, finish_reason
- def _process_single_case(self, eval_state: EvalState, i: int, task_id: str) -> TaskState:
+ def _process_single_case(
+ self, server_config: ServerConfig, eval_state: EvalState, i: int, task_id: str
+ ) -> TaskState:
question_text, prompt, expected = eval_state.get_case(i)
task_state = TaskState(
task_id=task_id,
prompt=prompt,
expected=expected,
- question_text=question_text
+ question_text=question_text,
+ server_name=server_config.name
)
try:
- response, tokens, tps_gen, t_gen_ms, finish_reason = self._make_request(eval_state, prompt)
+ response, tokens, tps_gen, t_gen_ms, finish_reason = self._make_request(server_config, eval_state, prompt)
result = response["choices"][0]["message"]["content"]
reasoning_content = response["choices"][0].get("message", {}).get("reasoning_content")
task_state.response = result
@@ -1008,7 +1032,11 @@ class Processor:
if finish_reason != "stop":
task_state.status = f"error: finish_reason={finish_reason}"
- eval_state.add_result(task_id, prompt, expected, result, None, {"finish_reason": finish_reason}, False, task_state.status, tokens, tps_gen, t_gen_ms, reasoning_content)
+ eval_state.add_result(
+ task_id, prompt, expected, result, None,
+ {"finish_reason": finish_reason}, False, task_state.status,
+ tokens, tps_gen, t_gen_ms, reasoning_content, server_config.name
+ )
eval_state.dump()
return task_state
@@ -1027,7 +1055,11 @@ class Processor:
task_state.grader_log = grader_log
task_state.status = "ok"
- eval_state.add_result(task_id, prompt, expected, result, answer, grader_log, is_correct, "ok", tokens, tps_gen, t_gen_ms, reasoning_content)
+ eval_state.add_result(
+ task_id, prompt, expected, result, answer,
+ grader_log, is_correct, "ok",
+ tokens, tps_gen, t_gen_ms, reasoning_content, server_config.name
+ )
eval_state.dump()
@@ -1036,47 +1068,106 @@ class Processor:
return task_state
+ @staticmethod
+ def _worker(
+ server_config: ServerConfig,
+ processor: "Processor",
+ eval_state: EvalState,
+ task_queue: Queue,
+ results_queue: Queue,
+ ):
+ """Worker that pulls tasks from a shared queue and sends them to its server."""
+ while True:
+ task = task_queue.get()
+ if task is None: # sentinel
+ task_queue.task_done()
+ break
+ try:
+ i, task_id = task
+ result = processor._process_single_case(server_config, eval_state, i, task_id)
+ results_queue.put(result)
+ finally:
+ task_queue.task_done()
+
def evaluate(self, eval_state: EvalState, verbose: bool = False, resume: bool = False):
total_tasks = len(eval_state.tasks)
eval_state.total = len(eval_state.all_tasks) if eval_state.all_tasks else total_tasks
eval_state.processed = 0
start_time = time.time()
+ # Print server info
+ server_lines = [
+ f" {i+1}. {sc.name} — {sc.url} ({sc.threads} threads)"
+ for i, sc in enumerate(self.server_configs)
+ ]
print(f"\nProcessing {len(eval_state.tasks)} {eval_state.dataset_type.upper()} tasks ...")
- print(f"Server: {self.server_url} (model: {self.model_name})")
+ print(f"Servers ({len(self.server_configs)}):")
+ for line in server_lines:
+ print(line)
+ print(f"Model: {self.model_name}")
print(f"Grader: {self.grader.grader_type}")
- print(f"Threads: {self.threads}")
print(f"Sampling: temp={eval_state.sampling_config.get('temperature', 'skip')}, top-k={eval_state.sampling_config.get('top_k', 'skip')}, top-p={eval_state.sampling_config.get('top_p', 'skip')}, min-p={eval_state.sampling_config.get('min_p', 'skip')}")
print()
+ # Shared task queue: all workers compete for tasks
+ task_queue: Queue = Queue()
+ for i, task_id in eval_state.tasks:
+ task_queue.put((i, task_id))
+
+ # Results queue: workers push completed TaskStates here
+ results_queue: Queue = Queue()
+
+ # Total worker threads across all servers
+ total_threads = sum(sc.threads for sc in self.server_configs)
+
+ # Add one sentinel per worker so every worker exits cleanly
+ for _ in range(total_threads):
+ task_queue.put(None)
+
+ # Launch workers: one ThreadPoolExecutor per server
+ executors: List[ThreadPoolExecutor] = []
+ worker_futures: List[Any] = []
+ for server_config in self.server_configs:
+ executor = ThreadPoolExecutor(max_workers=server_config.threads)
+ executors.append(executor)
+ for _ in range(server_config.threads):
+ future = executor.submit(
+ self._worker, server_config, self, eval_state,
+ task_queue, results_queue
+ )
+ worker_futures.append(future)
+
+ # Drain results as they complete. NOTE(review): results_queue.get() has no timeout, so if a worker dies without posting a result this loop hangs — relies on _process_single_case catching its own exceptions; confirm.
n_correct = 0
+ session_time = 0.0
+ completed_count = 0
- with ThreadPoolExecutor(max_workers=self.threads) as executor:
- futures = {
- executor.submit(self._process_single_case, eval_state, i, task_id): (i, task_id)
- for i, task_id in eval_state.tasks
- }
+ while completed_count < total_tasks:
+ task_state = results_queue.get()
+ eval_state.processed += 1
+ completed_count += 1
+ if task_state.correct:
+ n_correct += 1
+ elapsed = time.time() - start_time
+ eval_state.total_time += elapsed
+ session_time += elapsed
+ start_time = time.time()
+ eval_state.print_progress(task_state, total_tasks, n_correct)
- session_time = 0.0
- for future in as_completed(futures):
- task_state = future.result()
- eval_state.processed += 1
- if task_state.correct:
- n_correct += 1
- elapsed = time.time() - start_time
- eval_state.total_time += elapsed
- session_time += elapsed
- start_time = time.time()
- eval_state.print_progress(task_state, total_tasks, n_correct)
+ if verbose:
+ print(f"\nCase {eval_state.processed}: {task_state.correct}")
+ print(f" Expected: {task_state.expected}")
+ if task_state.response:
+ print(f" Response: {task_state.response}")
+ if task_state.answer:
+ print(f" Answer: {task_state.answer}")
+ print(f" Status: {task_state.status}")
- if verbose:
- print(f"\nCase {eval_state.processed}: {task_state.correct}")
- print(f" Expected: {task_state.expected}")
- if task_state.response:
- print(f" Response: {task_state.response}")
- if task_state.answer:
- print(f" Answer: {task_state.answer}")
- print(f" Status: {task_state.status}")
+ # Wait for all workers to finish — future.result() re-raises any uncaught worker exception — then shut down executors
+ for future in worker_futures:
+ future.result()
+ for executor in executors:
+ executor.shutdown(wait=True)
print(f"\nSession time: {session_time:.1f}s | Total accumulated time: {eval_state.total_time:.1f}s")
eval_state.print_summary()
@@ -1090,7 +1181,13 @@ def main():
"--server",
type=str,
default="http://localhost:8033",
- help="llama-server URL (default: http://localhost:8033)"
+ help="Comma-separated llama-server URLs (default: http://localhost:8033)"
+ )
+ parser.add_argument(
+ "--server-name",
+ type=str,
+ default="",
+ help="Comma-separated display names for servers (default: use URLs)"
)
parser.add_argument(
"--dataset",
@@ -1143,9 +1240,9 @@ def main():
)
parser.add_argument(
"--threads",
- type=int,
- default=32,
- help="Number of threads for parallel requests (default: 32)"
+ type=str,
+ default="32",
+ help="Comma-separated thread counts per server (default: 32)"
)
parser.add_argument(
"--model",
@@ -1197,6 +1294,28 @@ def main():
args = parser.parse_args()
+ # Parse server URLs and thread counts
+ server_urls = [u.strip() for u in args.server.split(",") if u.strip()]
+ thread_counts = [int(t.strip()) for t in args.threads.split(",") if t.strip()]
+
+ if len(server_urls) != len(thread_counts):
+ print(f"Error: --server ({len(server_urls)} URLs) and --threads ({len(thread_counts)} values) must have the same count")
+ sys.exit(1)
+
+ # Parse server names (optional, defaults to URLs)
+ if args.server_name:
+ server_names = [n.strip() for n in args.server_name.split(",") if n.strip()]
+ if len(server_names) != len(server_urls):
+ print(f"Error: --server-name ({len(server_names)} names) and --server ({len(server_urls)} URLs) must have the same count")
+ sys.exit(1)
+ else:
+ server_names = server_urls # fallback to URLs
+
+ server_configs = [
+ ServerConfig(url=url, threads=threads, name=name)
+ for url, threads, name in zip(server_urls, thread_counts, server_names)
+ ]
+
if args.dataset == "gpqa" and args.grader_type != "llm":
print("Error: GPQA dataset requires --grader-type llm")
parser.print_help()
@@ -1226,7 +1345,7 @@ def main():
eval_state.tasks = pending_tasks
eval_state.task_states["cases"] = existing_cases
- grader_server_url = args.grader_server if args.grader_server else args.server
+ grader_server_url = args.grader_server if args.grader_server else server_configs[0].url
grader_model_name = args.grader_model if args.grader_model else args.model
grader = Grader(
grader_type=args.grader_type,
@@ -1241,7 +1360,7 @@ def main():
print("Error: No existing eval state found to resume")
sys.exit(1)
- grader_server_url = args.grader_server if args.grader_server else args.server
+ grader_server_url = args.grader_server if args.grader_server else server_configs[0].url
grader_model_name = args.grader_model if args.grader_model else args.model
grader = Grader(
@@ -1278,10 +1397,9 @@ def main():
eval_state.print_all_tasks()
processor = Processor(
- server_url=args.server,
+ server_configs=server_configs,
grader=grader,
model_name=args.model,
- threads=args.threads,
n_predict=args.n_predict
)