diff --git a/examples/llama-eval/llama-eval.py b/examples/llama-eval/llama-eval.py index 9c9b337933..07f6e9f573 100755 --- a/examples/llama-eval/llama-eval.py +++ b/examples/llama-eval/llama-eval.py @@ -7,17 +7,26 @@ import os import re import subprocess import sys +import threading import time from abc import ABC, abstractmethod from concurrent.futures import ThreadPoolExecutor, as_completed from dataclasses import dataclass, asdict, field from pathlib import Path +from queue import Queue from typing import Dict, List, Optional, Any, Tuple import requests from tqdm import tqdm import random from math import sqrt + +@dataclass +class ServerConfig: + url: str + threads: int + name: str = "" + def wilson_interval(correct: int, total: int, z: float = 1.96) -> Tuple[float, float]: """Wilson score confidence interval for a proportion.""" if total == 0: @@ -126,6 +135,7 @@ class TaskState: tps_gen: Optional[float] = None t_gen_ms: Optional[float] = None reasoning_content: Optional[str] = None + server_name: Optional[str] = None class EvalState: @@ -146,6 +156,7 @@ class EvalState: self.correct = 0 self.processed = 0 self.total_time: float = 0.0 + self._lock = threading.Lock() def load_dataset(self, seed: int = 1234): if self.dataset_type == "aime": @@ -204,33 +215,37 @@ class EvalState: tokens: Optional[int] = None, tps_gen: Optional[float] = None, t_gen_ms: Optional[float] = None, - reasoning_content: Optional[str] = None + reasoning_content: Optional[str] = None, + server_name: Optional[str] = None ): - if "cases" not in self.task_states: - self.task_states["cases"] = {} + with self._lock: + if "cases" not in self.task_states: + self.task_states["cases"] = {} - self.task_states["cases"][task_id] = { - "task_id": task_id, - "prompt": prompt, - "expected": expected, - "response": response, - "answer": answer, - "grader_log": grader_log, - "correct": correct, - "status": status, - "tokens": tokens, - "tps_gen": tps_gen, - "t_gen_ms": t_gen_ms, - "reasoning_content": reasoning_content - } + self.task_states["cases"][task_id] = { + "task_id": task_id, + "prompt": prompt, + "expected": expected, + "response": response, + "answer": answer, + "grader_log": grader_log, + "correct": correct, + "status": status, + "tokens": tokens, + "tps_gen": tps_gen, + "t_gen_ms": t_gen_ms, + "reasoning_content": reasoning_content, + "server_name": server_name + } - self.correct = sum(1 for c in self.task_states.get("cases", {}).values() if c.get("correct", False)) + self.correct = sum(1 for c in self.task_states.get("cases", {}).values() if c.get("correct", False)) def print_progress(self, task_state: TaskState, total_tasks: int, n_correct: int = 0): display_answer = task_state.answer if task_state.answer else "N/A" display_tokens = str(task_state.tokens) if task_state.tokens is not None else "N/A" display_tps = f"{task_state.tps_gen:.1f}" if task_state.tps_gen is not None else "N/A" display_t_gen = f"{task_state.t_gen_ms/1000:.1f}" if task_state.t_gen_ms is not None else "N/A" + display_server = task_state.server_name if task_state.server_name else "N/A" success_ratio = n_correct / self.processed if self.processed > 0 else 0.0 first_line = task_state.question_text.split('\n')[0] truncated_question = first_line[:43] @@ -238,7 +253,7 @@ class EvalState: truncated_question += "..." else: truncated_question = truncated_question.ljust(43) + "..." - print(f"{self.processed:3}/{total_tasks:3} {task_state.task_id:<20} {self.dataset_type.upper()} {truncated_question:<40} {task_state.expected:<10} {display_answer:<10} {display_tokens:<6} {display_tps:<6} {display_t_gen:<8} {'✓' if task_state.correct else '✗'} [{n_correct:3}/{self.processed:3}, {success_ratio:.3f}]") + print(f"{self.processed:3}/{total_tasks:3} {task_state.task_id:<20} {self.dataset_type.upper()} {truncated_question:<40} {task_state.expected:<10} {display_answer:<10} {display_tokens:<6} {display_tps:<6} {display_t_gen:<8} {'✓' if task_state.correct else '✗'} [{n_correct:3}/{self.processed:3}, {success_ratio:.3f}] {display_server}") def print_summary(self): if self.total == 0: @@ -272,7 +287,8 @@ class EvalState: "tokens": None, "tps_gen": None, "t_gen_ms": None, - "reasoning_content": None + "reasoning_content": None, + "server_name": None } ci_lower, ci_upper = self.accuracy_ci() @@ -339,11 +355,13 @@ class EvalState: t_gen_ms = case.get("t_gen_ms") t_gen_str = f"{t_gen_ms/1000:.1f}" if t_gen_ms is not None else "" reasoning_content = case.get("reasoning_content", "") or "" + server_name = case.get("server_name", "") or "" escaped_response = self._escape_html(response) escaped_prompt = self._escape_html(prompt) escaped_reasoning = self._escape_html(reasoning_content) grader_log_str = self._escape_html(json.dumps(grader_log, indent=2)) + escaped_server = self._escape_html(server_name) rows.append(f""" {task_id} @@ -353,9 +371,10 @@ class EvalState: {tokens_str} {tps_str} {t_gen_str} + {escaped_server} - +

Prompt

{escaped_prompt}
@@ -425,6 +444,7 @@ class EvalState: Tokens T/s Gen s + Server @@ -527,6 +547,7 @@ class EvalState: tps_str = f"{tps_gen:.1f}" if tps_gen is not None else "N/A" t_gen_ms = case.get("t_gen_ms") t_gen_str = f"{t_gen_ms/1000:.1f}" if t_gen_ms is not None else "N/A" + server_name = case.get("server_name", "") or "" is_correct = case.get("correct", False) if status == "ok" else False symbol = "✓ " if is_correct else ("✗ " if status == "ok" else "") first_line = question.split('\n')[0] @@ -535,7 +556,7 @@ class EvalState: question_trunc += "..." else: question_trunc = question_trunc.ljust(43) + "..." - print(f" {task_id:<20} {self.dataset_type.upper()} {question_trunc:<40} {expected:<10} {answer:<10} {tokens_str:<6} {tps_str:<6} {t_gen_str:<8} {symbol}{status}") + print(f" {task_id:<20} {self.dataset_type.upper()} {question_trunc:<40} {expected:<10} {answer:<10} {tokens_str:<6} {tps_str:<6} {t_gen_str:<8} {symbol}{status} {server_name}") print() def print_existing_summary(self): @@ -947,20 +968,20 @@ Please provide only the extracted answer, nothing else. If there is no clear ans class Processor: def __init__( self, - server_url: str, + server_configs: List[ServerConfig], grader: Grader, model_name: Optional[str] = None, - threads: int = 32, n_predict: int = -1 ): - self.server_url = server_url + self.server_configs = server_configs self.grader = grader self.model_name = model_name - self.threads = threads self.n_predict = n_predict - def _make_request(self, eval_state: EvalState, prompt: str) -> Tuple[Dict[str, Any], int, Optional[float], Optional[float], str]: - url = f"{self.server_url}/v1/chat/completions" + def _make_request( + self, server_config: ServerConfig, eval_state: EvalState, prompt: str + ) -> Tuple[Dict[str, Any], int, Optional[float], Optional[float], str]: + url = f"{server_config.url}/v1/chat/completions" headers = {"Content-Type": "application/json"} data = { "model": self.model_name if self.model_name else "llama", @@ -986,18 +1007,21 @@ class Processor: finish_reason = result.get("choices", [{}])[0].get("finish_reason", "stop") return result, tokens, tps_gen, t_gen_ms, finish_reason - def _process_single_case(self, eval_state: EvalState, i: int, task_id: str) -> TaskState: + def _process_single_case( + self, server_config: ServerConfig, eval_state: EvalState, i: int, task_id: str + ) -> TaskState: question_text, prompt, expected = eval_state.get_case(i) task_state = TaskState( task_id=task_id, prompt=prompt, expected=expected, - question_text=question_text + question_text=question_text, + server_name=server_config.name ) try: - response, tokens, tps_gen, t_gen_ms, finish_reason = self._make_request(eval_state, prompt) + response, tokens, tps_gen, t_gen_ms, finish_reason = self._make_request(server_config, eval_state, prompt) result = response["choices"][0]["message"]["content"] reasoning_content = response["choices"][0].get("message", {}).get("reasoning_content") task_state.response = result @@ -1008,7 +1032,11 @@ class Processor: if finish_reason != "stop": task_state.status = f"error: finish_reason={finish_reason}" - eval_state.add_result(task_id, prompt, expected, result, None, {"finish_reason": finish_reason}, False, task_state.status, tokens, tps_gen, t_gen_ms, reasoning_content) + eval_state.add_result( + task_id, prompt, expected, result, None, + {"finish_reason": finish_reason}, False, task_state.status, + tokens, tps_gen, t_gen_ms, reasoning_content, server_config.name + ) eval_state.dump() return task_state @@ -1027,7 +1055,11 @@ class Processor: task_state.grader_log = grader_log task_state.status = "ok" - eval_state.add_result(task_id, prompt, expected, result, answer, grader_log, is_correct, "ok", tokens, tps_gen, t_gen_ms, reasoning_content) + eval_state.add_result( + task_id, prompt, expected, result, answer, + grader_log, is_correct, "ok", + tokens, tps_gen, t_gen_ms, reasoning_content, server_config.name + ) eval_state.dump() @@ -1036,47 +1068,106 @@ class Processor: return task_state + @staticmethod + def _worker( + server_config: ServerConfig, + processor: "Processor", + eval_state: EvalState, + task_queue: Queue, + results_queue: Queue, + ): + """Worker that pulls tasks from a shared queue and sends them to its server.""" + while True: + task = task_queue.get() + if task is None: # sentinel + task_queue.task_done() + break + try: + i, task_id = task + result = processor._process_single_case(server_config, eval_state, i, task_id) + results_queue.put(result) + finally: + task_queue.task_done() + def evaluate(self, eval_state: EvalState, verbose: bool = False, resume: bool = False): total_tasks = len(eval_state.tasks) eval_state.total = len(eval_state.all_tasks) if eval_state.all_tasks else total_tasks eval_state.processed = 0 start_time = time.time() + # Print server info + server_lines = [ + f" {i+1}. {sc.name} — {sc.url} ({sc.threads} threads)" + for i, sc in enumerate(self.server_configs) + ] print(f"\nProcessing {len(eval_state.tasks)} {eval_state.dataset_type.upper()} tasks ...") - print(f"Server: {self.server_url} (model: {self.model_name})") + print(f"Servers ({len(self.server_configs)}):") + for line in server_lines: + print(line) + print(f"Model: {self.model_name}") print(f"Grader: {self.grader.grader_type}") - print(f"Threads: {self.threads}") print(f"Sampling: temp={eval_state.sampling_config.get('temperature', 'skip')}, top-k={eval_state.sampling_config.get('top_k', 'skip')}, top-p={eval_state.sampling_config.get('top_p', 'skip')}, min-p={eval_state.sampling_config.get('min_p', 'skip')}") print() + # Shared task queue: all workers compete for tasks + task_queue: Queue = Queue() + for i, task_id in eval_state.tasks: + task_queue.put((i, task_id)) + + # Results queue: workers push completed TaskStates here + results_queue: Queue = Queue() + + # Total worker threads across all servers + total_threads = sum(sc.threads for sc in self.server_configs) + + # Add one sentinel per worker so every worker exits cleanly + for _ in range(total_threads): + task_queue.put(None) + + # Launch workers: one ThreadPoolExecutor per server + executors: List[ThreadPoolExecutor] = [] + worker_futures: List[Any] = [] + for server_config in self.server_configs: + executor = ThreadPoolExecutor(max_workers=server_config.threads) + executors.append(executor) + for _ in range(server_config.threads): + future = executor.submit( + self._worker, server_config, self, eval_state, + task_queue, results_queue + ) + worker_futures.append(future) + + # Drain results as they complete n_correct = 0 + session_time = 0.0 + completed_count = 0 - with ThreadPoolExecutor(max_workers=self.threads) as executor: - futures = { - executor.submit(self._process_single_case, eval_state, i, task_id): (i, task_id) - for i, task_id in eval_state.tasks - } + while completed_count < total_tasks: + task_state = results_queue.get() + eval_state.processed += 1 + completed_count += 1 + if task_state.correct: + n_correct += 1 + elapsed = time.time() - start_time + eval_state.total_time += elapsed + session_time += elapsed + start_time = time.time() + eval_state.print_progress(task_state, total_tasks, n_correct) - session_time = 0.0 - for future in as_completed(futures): - task_state = future.result() - eval_state.processed += 1 - if task_state.correct: - n_correct += 1 - elapsed = time.time() - start_time - eval_state.total_time += elapsed - session_time += elapsed - start_time = time.time() - eval_state.print_progress(task_state, total_tasks, n_correct) + if verbose: + print(f"\nCase {eval_state.processed}: {task_state.correct}") + print(f" Expected: {task_state.expected}") + if task_state.response: + print(f" Response: {task_state.response}") + if task_state.answer: + print(f" Answer: {task_state.answer}") + print(f" Status: {task_state.status}") - if verbose: - print(f"\nCase {eval_state.processed}: {task_state.correct}") - print(f" Expected: {task_state.expected}") - if task_state.response: - print(f" Response: {task_state.response}") - if task_state.answer: - print(f" Answer: {task_state.answer}") - print(f" Status: {task_state.status}") + # Wait for all workers to finish and shut down executors + for future in worker_futures: + future.result() + for executor in executors: + executor.shutdown(wait=True) print(f"\nSession time: {session_time:.1f}s | Total accumulated time: {eval_state.total_time:.1f}s") eval_state.print_summary() @@ -1090,7 +1181,13 @@ def main(): "--server", type=str, default="http://localhost:8033", - help="llama-server URL (default: http://localhost:8033)" + help="Comma-separated llama-server URLs (default: http://localhost:8033)" + ) + parser.add_argument( + "--server-name", + type=str, + default="", + help="Comma-separated display names for servers (default: use URLs)" ) parser.add_argument( "--dataset", @@ -1143,9 +1240,9 @@ def main(): ) parser.add_argument( "--threads", - type=int, - default=32, - help="Number of threads for parallel requests (default: 32)" + type=str, + default="32", + help="Comma-separated thread counts per server (default: 32)" ) parser.add_argument( "--model", @@ -1197,6 +1294,28 @@ def main(): args = parser.parse_args() + # Parse server URLs and thread counts + server_urls = [u.strip() for u in args.server.split(",") if u.strip()] + thread_counts = [int(t.strip()) for t in args.threads.split(",") if t.strip()] + + if len(server_urls) != len(thread_counts): + print(f"Error: --server ({len(server_urls)} URLs) and --threads ({len(thread_counts)} values) must have the same count") + sys.exit(1) + + # Parse server names (optional, defaults to URLs) + if args.server_name: + server_names = [n.strip() for n in args.server_name.split(",") if n.strip()] + if len(server_names) != len(server_urls): + print(f"Error: --server-name ({len(server_names)} names) and --server ({len(server_urls)} URLs) must have the same count") + sys.exit(1) + else: + server_names = server_urls # fallback to URLs + + server_configs = [ + ServerConfig(url=url, threads=threads, name=name) + for url, threads, name in zip(server_urls, thread_counts, server_names) + ] + if args.dataset == "gpqa" and args.grader_type != "llm": print("Error: GPQA dataset requires --grader-type llm") parser.print_help() @@ -1226,7 +1345,7 @@ def main(): eval_state.tasks = pending_tasks eval_state.task_states["cases"] = existing_cases - grader_server_url = args.grader_server if args.grader_server else args.server + grader_server_url = args.grader_server if args.grader_server else server_configs[0].url grader_model_name = args.grader_model if args.grader_model else args.model grader = Grader( grader_type=args.grader_type, @@ -1241,7 +1360,7 @@ def main(): print("Error: No existing eval state found to resume") sys.exit(1) - grader_server_url = args.grader_server if args.grader_server else args.server + grader_server_url = args.grader_server if args.grader_server else server_configs[0].url grader_model_name = args.grader_model if args.grader_model else args.model grader = Grader( @@ -1278,10 +1397,9 @@ def main(): eval_state.print_all_tasks() processor = Processor( - server_url=args.server, + server_configs=server_configs, grader=grader, model_name=args.model, - threads=args.threads, n_predict=args.n_predict )