diff --git a/examples/llama-eval/README.md b/examples/llama-eval/README.md
new file mode 100644
index 0000000000..3c5c35f78f
--- /dev/null
+++ b/examples/llama-eval/README.md
@@ -0,0 +1,26 @@
+# llama-eval
+
+Simple evaluation tool for llama.cpp with support for multiple datasets.
+
+For a full description, usage examples, and sample results, see:
+
+- [PR 21152](https://github.com/ggml-org/llama.cpp/pull/21152)
+
+## Quick start
+
+```bash
+# Single server
+python3 llama-eval.py \
+ --server http://localhost:8033 \
+ --model my-model \
+ --dataset gsm8k --n_cases 100 \
+ --grader-type regex --threads 32
+
+# Multiple servers (comma-separated URLs and thread counts)
+python3 llama-eval.py \
+ --server http://server1:8033,http://server2:8033 \
+ --server-name server1,server2 \
+ --threads 16,16 \
+ --dataset aime2025 --n_cases 240 \
+ --grader-type regex
+```
diff --git a/examples/llama-eval/llama-eval.py b/examples/llama-eval/llama-eval.py
new file mode 100755
index 0000000000..b33a3615be
--- /dev/null
+++ b/examples/llama-eval/llama-eval.py
@@ -0,0 +1,1416 @@
+#!/usr/bin/env python3
+# type: ignore
+
+import argparse
+import json
+import os
+import re
+import subprocess
+import sys
+import threading
+import time
+from abc import ABC, abstractmethod
+from concurrent.futures import ThreadPoolExecutor, as_completed
+from dataclasses import dataclass, asdict, field
+from pathlib import Path
+from queue import Queue
+from typing import Dict, List, Optional, Any, Tuple
+import requests
+from tqdm import tqdm
+import random
+from math import sqrt
+
+
+@dataclass
+class ServerConfig:
+ url: str
+ threads: int
+ name: str = ""
+
+def wilson_interval(correct: int, total: int, z: float = 1.96) -> Tuple[float, float]:
+ """Wilson score confidence interval for a proportion."""
+ if total == 0:
+ return (0.0, 1.0)
+ p = correct / total
+ z2 = z * z / total
+ center = (p + z2 / 2) / (1 + z2)
+ margin = z * sqrt((p * (1 - p) + z2 / 4) / total) / (1 + z2)
+ return (center - margin, center + margin)
+
+cache_dir = Path.home() / ".cache" / "huggingface" / "datasets"
+cache_dir.mkdir(parents=True, exist_ok=True)
+os.environ["HF_DATASETS_CACHE"] = str(cache_dir)
+os.environ["HF_HUB_DISABLE_TELEMETRY"] = "1"
+
+GRADER_PATTERNS = {
+ "aime": r'\boxed{(\d+)}|\b(\d+)\b',
+ "aime2025": r'\boxed{(\d+)}|\b(\d+)\b',
+ "gsm8k": r'\b(\d+)\b',
+}
+
+SAMPLE_ANSWERS = {
+ "aime": [
+ "42",
+ "-123",
+ "999"
+ ],
+ "aime2025": [
+ "42",
+ "-123",
+ "999"
+ ],
+ "gsm8k": [
+ "42",
+ "-123",
+ "999"
+ ],
+ "gpqa": [
+ "A",
+ "D",
+ "C"
+ ],
+}
+
+TEMPLATE_REGISTRY = {
+ "aime": """Solve the following math problem step by step. Put your answer inside \\boxed{{}}.
+
+{question}
+
+Remember to put your answer inside \\boxed{{}}.
+""",
+ "aime2025": """Solve the following math problem step by step. Put your answer inside \\boxed{{}}.
+
+{question}
+
+Remember to put your answer inside \\boxed{{}}.
+""",
+ "gsm8k": """{question}
+Please reason step by step, and put your final numeric answer within \\boxed{{}} without any extra characters.
+""",
+ "gpqa": """Answer the following multiple choice question. The last line of your response should be in the following format: 'Answer: A/B/C/D' (e.g. 'Answer: A').
+
+{Question}
+
+A) {A}
+B) {B}
+C) {C}
+D) {D}
+""",
+}
+
+
+class BaseDataset(ABC):
+ @abstractmethod
+ def get_question(self, index: int) -> Dict:
+ pass
+
+ @abstractmethod
+ def get_question_text(self, question: Dict) -> str:
+ pass
+
+ @abstractmethod
+ def get_answer(self, question: Dict) -> str:
+ pass
+
+ @abstractmethod
+ def get_prompt(self, question: Dict) -> str:
+ pass
+
+ def __len__(self) -> int:
+ return len(self.questions)
+
+
+@dataclass
+class TaskState:
+ task_id: str
+ prompt: str
+ expected: str
+ question_text: str = ""
+ response: Optional[str] = None
+ answer: Optional[str] = None
+ grader_log: Dict[str, Any] = field(default_factory=dict)
+ correct: bool = False
+ status: str = "pending"
+ tokens: Optional[int] = None
+ tps_gen: Optional[float] = None
+ t_gen_ms: Optional[float] = None
+ reasoning_content: Optional[str] = None
+ server_name: Optional[str] = None
+
+
+class EvalState:
+ def __init__(
+ self,
+ dataset_type: str,
+ sampling_config: Dict[str, Any],
+ output_file: Path = Path("llama-eval-state.json"),
+ model_name: Optional[str] = None
+ ):
+ self.dataset_type = dataset_type
+ self.sampling_config = sampling_config
+ self.output_file = output_file
+ self.model_name = model_name
+ self.dataset: Optional[BaseDataset] = None
+ self.tasks: List[Tuple[int, str]] = []
+ self.all_tasks: List[Tuple[int, str]] = []
+ self.task_states: Dict[str, Any] = {}
+ self.total = 0
+ self.correct = 0
+ self.processed = 0
+ self.total_time: float = 0.0
+ self._lock = threading.Lock()
+
+ def load_dataset(self, seed: int = 1234):
+ if self.dataset_type == "aime":
+ self.dataset = AimeDataset()
+ elif self.dataset_type == "aime2025":
+ self.dataset = Aime2025Dataset()
+ elif self.dataset_type == "gsm8k":
+ self.dataset = Gsm8kDataset()
+ elif self.dataset_type == "gpqa":
+ self.dataset = GpqaDataset(variant="diamond", seed=seed)
+ else:
+ raise ValueError(f"Unknown dataset type: {self.dataset_type}")
+
+ def setup_tasks(self, n_cases: Optional[int] = None, seed: int = 1234):
+ if self.dataset is None:
+ raise ValueError("Dataset not loaded. Call load_dataset() first.")
+
+ if n_cases is None:
+ n_cases = len(self.dataset)
+
+ dataset_size = len(self.dataset)
+ rng = random.Random(seed)
+
+ self.tasks = []
+ for chunk_idx in range((n_cases + dataset_size - 1) // dataset_size):
+ chunk_size = min(dataset_size, n_cases - chunk_idx * dataset_size)
+ indices = list(range(dataset_size))
+ rng.shuffle(indices)
+ chunk_indices = indices[:chunk_size]
+
+ for i in chunk_indices:
+ task_id = f"{self.dataset_type}_{chunk_idx:03d}_{i:03d}"
+ self.tasks.append((i, task_id))
+
+ self.all_tasks = list(self.tasks)
+
+ def get_case(self, index: int) -> Tuple[str, str, str]:
+ if self.dataset is None:
+ raise ValueError("Dataset not loaded.")
+ question = self.dataset.get_question(index)
+ question_text = self.dataset.get_question_text(question)
+ prompt = self.dataset.get_prompt(question)
+ expected = self.dataset.get_answer(question)
+ return question_text, prompt, expected
+
+ def add_result(
+ self,
+ task_id: str,
+ prompt: str,
+ expected: str,
+ response: Optional[str],
+ answer: Optional[str],
+ grader_log: Dict[str, Any],
+ correct: bool,
+ status: str,
+ tokens: Optional[int] = None,
+ tps_gen: Optional[float] = None,
+ t_gen_ms: Optional[float] = None,
+ reasoning_content: Optional[str] = None,
+ server_name: Optional[str] = None
+ ):
+ with self._lock:
+ if "cases" not in self.task_states:
+ self.task_states["cases"] = {}
+
+ self.task_states["cases"][task_id] = {
+ "task_id": task_id,
+ "prompt": prompt,
+ "expected": expected,
+ "response": response,
+ "answer": answer,
+ "grader_log": grader_log,
+ "correct": correct,
+ "status": status,
+ "tokens": tokens,
+ "tps_gen": tps_gen,
+ "t_gen_ms": t_gen_ms,
+ "reasoning_content": reasoning_content,
+ "server_name": server_name
+ }
+
+ self.correct = sum(1 for c in self.task_states.get("cases", {}).values() if c.get("correct", False))
+
+ def print_progress(self, task_state: TaskState, total_tasks: int, n_correct: int = 0):
+ display_answer = task_state.answer if task_state.answer else "N/A"
+ display_tokens = str(task_state.tokens) if task_state.tokens is not None else "N/A"
+ display_tps = f"{task_state.tps_gen:.1f}" if task_state.tps_gen is not None else "N/A"
+ display_t_gen = f"{task_state.t_gen_ms/1000:.1f}" if task_state.t_gen_ms is not None else "N/A"
+ display_server = task_state.server_name if task_state.server_name else "N/A"
+ success_ratio = n_correct / self.processed if self.processed > 0 else 0.0
+ first_line = task_state.question_text.split('\n')[0]
+ truncated_question = first_line[:43]
+ if len(first_line) > 43:
+ truncated_question += "..."
+ else:
+ truncated_question = truncated_question.ljust(43) + "..."
+ print(f"{self.processed:3}/{total_tasks:3} {task_state.task_id:<20} {self.dataset_type.upper()} {truncated_question:<40} {task_state.expected:<10} {display_answer:<10} {display_tokens:<6} {display_tps:<6} {display_t_gen:<8} {'✓' if task_state.correct else '✗'} [{n_correct:3}/{self.processed:3}, {success_ratio:.3f}] {display_server}")
+
+ def print_summary(self):
+ if self.total == 0:
+ print(f"\n{'='*60}")
+ print(f"Results: 0/0 correct (0.0%)")
+ print(f"{'='*60}")
+ else:
+ ci_lower, ci_upper = self.accuracy_ci()
+ print(f"\n{'='*60}")
+ print(f"Results: {self.correct}/{self.total} correct ({self.correct/self.total*100:.1f}%) [{ci_lower*100:.1f}%, {ci_upper*100:.1f}%]")
+ print(f"{'='*60}")
+
+ def dump(self):
+ with self._lock:
+ tasks_to_save = self.all_tasks if self.all_tasks else self.tasks
+ all_cases = {}
+ for i, task_id in tasks_to_save:
+ question_text, prompt, expected = self.get_case(i)
+ if task_id in self.task_states.get("cases", {}):
+ all_cases[task_id] = self.task_states["cases"][task_id]
+ else:
+ all_cases[task_id] = {
+ "task_id": task_id,
+ "prompt": prompt,
+ "expected": expected,
+ "question_text": question_text,
+ "response": None,
+ "answer": None,
+ "grader_log": {},
+ "correct": False,
+ "status": "pending",
+ "tokens": None,
+ "tps_gen": None,
+ "t_gen_ms": None,
+ "reasoning_content": None,
+ "server_name": None
+ }
+
+ ci_lower, ci_upper = self.accuracy_ci()
+ data = {
+ "id": self.dataset_type,
+ "model_name": self.model_name,
+ "tasks": [tid for _, tid in tasks_to_save],
+ "task_states": {
+ "total": self.total,
+ "correct": self.correct,
+ "total_time": self.total_time,
+ "ci_lower": ci_lower,
+ "ci_upper": ci_upper,
+ "cases": all_cases,
+ },
+ "sampling_config": self.sampling_config
+ }
+ with open(self.output_file, "w") as f:
+ json.dump(data, f, indent=2)
+
+ self.dump_html(tasks_to_save, all_cases)
+
+ def dump_html(self, tasks_to_save: List[Tuple[int, str]], all_cases: Dict[str, Any]):
+ html_file = Path(str(self.output_file) + ".html")
+
+ cases = all_cases
+ completed = {tid: c for tid, c in cases.items() if c.get("status") == "ok"}
+ n_correct = sum(1 for c in completed.values() if c.get("correct", False))
+ n_incorrect = len(completed) - n_correct
+ n_pending = len(tasks_to_save) - len(completed)
+ accuracy = n_correct / len(completed) * 100 if completed else 0.0
+ ci_lower, ci_upper = wilson_interval(n_correct, len(completed)) if completed else (0.0, 1.0)
+
+ sampling_parts = []
+ for k, v in self.sampling_config.items():
+ if v is not None:
+ sampling_parts.append(f"{k}={v}")
+ sampling_str = ", ".join(sampling_parts) if sampling_parts else "default"
+
+ rows = []
+ for i, task_id in tasks_to_save:
+ case = cases.get(task_id, {})
+ status = case.get("status", "pending")
+ expected = case.get("expected", "")
+ answer = case.get("answer", "") if status == "ok" else ""
+ is_correct = case.get("correct", False) if status == "ok" else False
+ response = case.get("response", "") or ""
+ prompt = case.get("prompt", "") or ""
+ grader_log = case.get("grader_log", {})
+
+ if status == "ok":
+ status_class = "correct" if is_correct else "incorrect"
+ status_text = "✓" if is_correct else "✗"
+ elif status == "pending":
+ status_class = "pending"
+ status_text = "–"
+ else:
+ status_class = "error"
+ status_text = "!"
+
+ tokens = case.get("tokens")
+ tokens_str = str(tokens) if tokens is not None else ""
+ tps_gen = case.get("tps_gen")
+ tps_str = f"{tps_gen:.1f}" if tps_gen is not None else ""
+ t_gen_ms = case.get("t_gen_ms")
+ t_gen_str = f"{t_gen_ms/1000:.1f}" if t_gen_ms is not None else ""
+ reasoning_content = case.get("reasoning_content", "") or ""
+ server_name = case.get("server_name", "") or ""
+
+ escaped_response = self._escape_html(response)
+ escaped_prompt = self._escape_html(prompt)
+ escaped_reasoning = self._escape_html(reasoning_content)
+ grader_log_str = self._escape_html(json.dumps(grader_log, indent=2))
+ escaped_server = self._escape_html(server_name)
+
+ rows.append(f"""
+ | {task_id} |
+ {status_text} |
+ {self._escape_html(expected)} |
+ {self._escape_html(answer)} |
+ {tokens_str} |
+ {tps_str} |
+ {t_gen_str} |
+ {escaped_server} |
+
+
+
+
+ Prompt{escaped_prompt}
+ Response{escaped_response}
+ {f' Reasoning{escaped_reasoning}' if escaped_reasoning else ''}
+ Grader{grader_log_str}
+
+ |
+
""")
+
+ rows_html = "\n".join(rows)
+
+ html_content = f"""
+
+
+
+{self.dataset_type.upper()} Eval
+
+
+
+
+ {self.dataset_type.upper()}
+ Model: {self.model_name or 'N/A'}
+ Accuracy: {accuracy:.1f}% [{ci_lower*100:.1f}%, {ci_upper*100:.1f}%]
+ Correct: {n_correct} / {len(completed)}
+ Pending: {n_pending}
+ Time: {self.total_time:.1f}s
+ Sampling: {sampling_str}
+
+
+
+
+ | ID |
+ |
+ Gold |
+ Answer |
+ Tokens |
+ T/s |
+ Gen s |
+ Server |
+
+
+
+ {rows_html}
+
+
+
+
+"""
+
+ with open(html_file, "w") as f:
+ f.write(html_content)
+
+ def _escape_html(self, s: str) -> str:
+ return (s.replace("&", "&")
+ .replace("<", "<")
+ .replace(">", ">")
+ .replace('"', """)
+ .replace("'", "'"))
+
+ @classmethod
+ def load(cls, path: Path) -> "EvalState":
+ with open(path, "r") as f:
+ data = json.load(f)
+
+ eval_state = cls(
+ dataset_type=data["id"],
+ sampling_config=data["sampling_config"],
+ output_file=path,
+ model_name=data.get("model_name")
+ )
+ eval_state.load_dataset()
+
+ eval_state.tasks = []
+ eval_state.all_tasks = []
+ for task_id in data.get("tasks", []):
+ parts = task_id.rsplit("_", 2)
+ if len(parts) >= 3:
+ idx = int(parts[-1])
+ else:
+ idx = 0
+ eval_state.tasks.append((idx, task_id))
+ eval_state.all_tasks.append((idx, task_id))
+
+ eval_state.task_states = data.get("task_states", {})
+
+ cases = eval_state.task_states.get("cases", {})
+ eval_state.total = eval_state.task_states.get("total", 0)
+ eval_state.correct = eval_state.task_states.get("correct", 0)
+ eval_state.total_time = eval_state.task_states.get("total_time", 0.0)
+
+ if eval_state.total == 0:
+ eval_state.total = len(cases)
+ eval_state.correct = sum(1 for c in cases.values() if c.get("correct", False))
+
+ return eval_state
+
+ def is_complete(self) -> bool:
+ if not self.all_tasks:
+ return False
+ cases = self.task_states.get("cases", {})
+ completed = {tid for tid in self.task_states.get("cases", {}).keys() if cases.get(tid, {}).get("status") == "ok"}
+ return len(completed) == len(self.all_tasks)
+
+ def get_pending_tasks(self) -> List[Tuple[int, str]]:
+ cases = self.task_states.get("cases", {})
+ pending = []
+ for i, task_id in self.all_tasks:
+ status = cases.get(task_id, {}).get("status", "pending")
+ if status != "ok":
+ pending.append((i, task_id))
+ return pending
+
+ def print_all_tasks(self):
+ cases = self.task_states.get("cases", {})
+ tasks_to_show = self.all_tasks if self.all_tasks else self.tasks
+ print()
+ print("Tasks:")
+ print(" Task ID Dataset Prompt (first 40 chars) Expected Answer Tokens T/s Gen s Status")
+ for i, task_id in tasks_to_show:
+ question, prompt, expected = self.get_case(i)
+ case = cases.get(task_id, {})
+ status = case.get("status", "pending")
+ answer = case.get("answer", "N/A") if status == "ok" else "N/A"
+ tokens = case.get("tokens")
+ tokens_str = str(tokens) if tokens is not None else "N/A"
+ tps_gen = case.get("tps_gen")
+ tps_str = f"{tps_gen:.1f}" if tps_gen is not None else "N/A"
+ t_gen_ms = case.get("t_gen_ms")
+ t_gen_str = f"{t_gen_ms/1000:.1f}" if t_gen_ms is not None else "N/A"
+ server_name = case.get("server_name", "") or ""
+ is_correct = case.get("correct", False) if status == "ok" else False
+ symbol = "✓ " if is_correct else ("✗ " if status == "ok" else "")
+ first_line = question.split('\n')[0]
+ question_trunc = first_line[:43]
+ if len(first_line) > 43:
+ question_trunc += "..."
+ else:
+ question_trunc = question_trunc.ljust(43) + "..."
+ print(f" {task_id:<20} {self.dataset_type.upper()} {question_trunc:<40} {expected:<10} {answer:<10} {tokens_str:<6} {tps_str:<6} {t_gen_str:<8} {symbol}{status} {server_name}")
+ print()
+
+ def print_existing_summary(self):
+ cases = self.task_states.get("cases", {})
+ completed_cases = {tid: c for tid, c in cases.items() if c.get("status") == "ok"}
+ correct = sum(1 for c in completed_cases.values() if c.get("correct", False))
+ total = len(completed_cases)
+ if total == 0:
+ print(f"{'='*60}")
+ print(f"Results: 0/0 correct (0.0%)")
+ print(f"{'='*60}")
+ else:
+ ci_lower, ci_upper = self.accuracy_ci()
+ print(f"{'='*60}")
+ print(f"Results: {correct}/{total} correct ({correct/total*100:.1f}%) [{ci_lower*100:.1f}%, {ci_upper*100:.1f}%]")
+ print(f"{'='*60}")
+
+ def accuracy_ci(self) -> Tuple[float, float]:
+ """Compute Wilson score confidence interval from completed cases."""
+ cases = self.task_states.get("cases", {})
+ completed = {tid: c for tid, c in cases.items() if c.get("status") == "ok"}
+ correct = sum(1 for c in completed.values() if c.get("correct", False))
+ total = len(completed)
+ return wilson_interval(correct, total)
+
+def normalize_number(s: str) -> Optional[int]:
+ match = re.match(r"\d+", s) # match digits from the start
+ if not match:
+ return None
+ return int(match.group(0))
+
+class AimeDataset(BaseDataset):
+ def __init__(self, split: str = "train"):
+ self.split = split
+ self.questions: List[Dict] = []
+ self._load_dataset()
+
+ def _load_dataset(self):
+ print(f"Loading AIME dataset (split: {self.split})...")
+ from datasets import load_dataset
+
+ cache_path = cache_dir / "AI-MO___aimo-validation-aime" / "default" / "0.0.0"
+ if cache_path.exists():
+ print(f"Using cached dataset from {cache_path}")
+ ds = load_dataset("AI-MO/aimo-validation-aime", split=self.split, cache_dir=str(cache_path))
+ else:
+ ds = load_dataset("AI-MO/aimo-validation-aime", split=self.split)
+
+ self.questions = []
+ for row in ds:
+ question = dict(row)
+ question["dataset_type"] = "aime"
+ self.questions.append(question)
+
+ print(f"AIME dataset loaded: {len(self.questions)} questions")
+
+ def get_question(self, index: int) -> Dict:
+ """Get question by index"""
+ return self.questions[index]
+
+ def get_question_text(self, question: Dict) -> str:
+ """Get question string"""
+ return question["problem"] if "problem" in question else question["question"]
+
+ def get_answer(self, question: Dict) -> str:
+ answer = question["answer"]
+ if isinstance(answer, str):
+ normalized = normalize_number(answer)
+ return str(normalized) if normalized is not None else answer
+ return str(answer)
+
+ def get_prompt(self, question: Dict) -> str:
+ """Get formatted prompt for the question"""
+ return TEMPLATE_REGISTRY[question["dataset_type"]].format(
+ question=self.get_question_text(question),
+ )
+
+class Aime2025Dataset(BaseDataset):
+ def __init__(self):
+ self.questions: List[Dict] = []
+ self._load_dataset()
+
+ def _load_dataset(self):
+ print(f"Loading AIME2025 dataset...")
+ from datasets import load_dataset
+
+ config_name = "AIME2025-I"
+ cache_path = cache_dir / "opencompass___AIME2025" / "default" / "0.0.0"
+ if cache_path.exists():
+ print(f"Using cached dataset from {cache_path}")
+ ds = load_dataset("opencompass/AIME2025", config_name, split="test", cache_dir=str(cache_path))
+ else:
+ ds = load_dataset("opencompass/AIME2025", config_name, split="test")
+
+ self.questions = []
+ for row in ds:
+ question = dict(row)
+ question["dataset_type"] = "aime2025"
+ self.questions.append(question)
+
+ print(f"AIME2025 dataset loaded: {len(self.questions)} questions")
+
+ print(f"Loading AIME2025 dataset (part 2)...")
+ config_name_2 = "AIME2025-II"
+ cache_path_2 = cache_dir / "opencompass___AIME2025" / "default" / "0.0.0"
+ if cache_path_2.exists():
+ print(f"Using cached dataset from {cache_path_2}")
+ ds_2 = load_dataset("opencompass/AIME2025", config_name_2, split="test", cache_dir=str(cache_path_2))
+ else:
+ ds_2 = load_dataset("opencompass/AIME2025", config_name_2, split="test")
+
+ for row in ds_2:
+ question = dict(row)
+ question["dataset_type"] = "aime2025"
+ self.questions.append(question)
+
+ print(f"AIME2025 dataset loaded: {len(self.questions)} questions (total)")
+
+ def get_question(self, index: int) -> Dict:
+ """Get question by index"""
+ return self.questions[index]
+
+ def get_question_text(self, question: Dict) -> str:
+ """Get question string"""
+ return question["question"]
+
+ def get_answer(self, question: Dict) -> str:
+ answer = question["answer"]
+ if isinstance(answer, str):
+ normalized = normalize_number(answer)
+ return str(normalized) if normalized is not None else answer
+ return str(answer)
+
+ def get_prompt(self, question: Dict) -> str:
+ """Get formatted prompt for the question"""
+ return TEMPLATE_REGISTRY["aime2025"].format(
+ question=self.get_question_text(question),
+ )
+
+class Gsm8kDataset(BaseDataset):
+ def __init__(self, split: str = "test"):
+ self.split = split
+ self.questions: List[Dict] = []
+ self._load_dataset()
+
+ def _load_dataset(self):
+ print(f"Loading GSM8K dataset (split: {self.split})...")
+ from datasets import load_dataset
+
+ cache_path = cache_dir / "openai___gsm8k" / "default" / "0.0.0"
+ if cache_path.exists():
+ print(f"Using cached dataset from {cache_path}")
+ ds = load_dataset("openai/gsm8k", "main", split=self.split, cache_dir=str(cache_path))
+ else:
+ ds = load_dataset("openai/gsm8k", "main", split=self.split)
+
+ self.questions = []
+ for row in ds:
+ question = dict(row)
+ question["dataset_type"] = "gsm8k"
+
+ # Extract numeric answer from the answer field (already has #### prefix)
+ gold = question["answer"]
+ # Split by #### and take the last part
+ parts = gold.split("####")
+ if len(parts) > 1:
+ gold = parts[-1].strip()
+ # Extract the first number from the remaining text
+ normalized = normalize_number(gold)
+ question["gold"] = str(normalized) if normalized is not None else gold
+
+ self.questions.append(question)
+
+ print(f"GSM8K dataset loaded: {len(self.questions)} questions")
+
+ def get_question(self, index: int) -> Dict:
+ """Get question by index"""
+ return self.questions[index]
+
+ def get_question_text(self, question: Dict) -> str:
+ """Get question string"""
+ return question["problem"] if "problem" in question else question["question"]
+
+ def get_answer(self, question: Dict) -> str:
+ # GSM8K has pre-extracted gold field, AIME uses answer field
+ if "gold" in question:
+ return question["gold"]
+ answer = question["answer"]
+ if isinstance(answer, str):
+ normalized = normalize_number(answer)
+ return str(normalized) if normalized is not None else answer
+ return str(answer)
+
+ def get_prompt(self, question: Dict) -> str:
+ """Get formatted prompt for the question"""
+ return TEMPLATE_REGISTRY[question["dataset_type"]].format(
+ question=self.get_question_text(question),
+ )
+
+class GpqaDataset(BaseDataset):
+ def __init__(self, variant: str = "diamond", seed: int = 1234):
+ self.variant = variant
+ self.seed = seed
+ self.questions: List[Dict] = []
+ self._load_dataset()
+
+ def _load_dataset(self):
+ print(f"Loading GPQA dataset (variant: {self.variant})...")
+ import pandas as pd
+
+ url = f"https://openaipublic.blob.core.windows.net/simple-evals/gpqa_{self.variant}.csv"
+ df = pd.read_csv(url)
+
+ rng = random.Random(self.seed)
+
+ self.questions = []
+ for _, row in df.iterrows():
+ question = row.to_dict()
+ question["dataset_type"] = "gpqa"
+
+ # Shuffle the answer options
+ correct_answer = question["Correct Answer"]
+ incorrect_answers = [
+ question["Incorrect Answer 1"],
+ question["Incorrect Answer 2"],
+ question["Incorrect Answer 3"]
+ ]
+
+ # Create list of (answer, is_correct) tuples
+ options = [(ans, ans == correct_answer) for ans in incorrect_answers]
+ options.append((correct_answer, True))
+
+ # Shuffle the options
+ rng.shuffle(options)
+
+ # Extract shuffled answers and determine correct letter
+ shuffled_answers = [ans for ans, _ in options]
+ correct_letter = chr(ord('A') + options.index((correct_answer, True)))
+
+ # Store shuffled answers and correct letter
+ question["shuffled_answers"] = shuffled_answers
+ question["correct_letter"] = correct_letter
+
+ self.questions.append(question)
+
+ print(f"GPQA dataset loaded: {len(self.questions)} questions")
+
+ def get_question(self, index: int) -> Dict:
+ """Get question by index"""
+ return self.questions[index]
+
+ def get_question_text(self, question: Dict) -> str:
+ """Get question string"""
+ return question["Question"]
+
+ def get_answer(self, question: Dict) -> str:
+ # GPQA returns the correct letter (A, B, C, or D)
+ return question["correct_letter"]
+
+ def get_prompt(self, question: Dict) -> str:
+ """Get formatted prompt for the question"""
+ return TEMPLATE_REGISTRY["gpqa"].format(
+ Question=self.get_question_text(question),
+ A=question["shuffled_answers"][0],
+ B=question["shuffled_answers"][1],
+ C=question["shuffled_answers"][2],
+ D=question["shuffled_answers"][3]
+ )
+
+class Grader:
+ def __init__(
+ self,
+ grader_type: str = "llm",
+ grader_script: Optional[str] = None,
+ grader_model_name: Optional[str] = None,
+ grader_server_url: str = "",
+ dataset_type: str = "aime"
+ ):
+ self.grader_type = grader_type
+ self.grader_script = grader_script
+ self.grader_model_name = grader_model_name
+ self.grader_server_url = grader_server_url
+ self.dataset_type = dataset_type
+ self.pattern = self._get_pattern()
+
+ def _get_pattern(self) -> Optional[str]:
+ if self.grader_type == "regex":
+ return GRADER_PATTERNS.get(self.dataset_type) # Use dataset_type as key
+ return None
+
+ def _extract_answer_regex(self, pred: str) -> Optional[str]:
+ """Extract answer using regex pattern"""
+ if not self.pattern:
+ return None
+
+ # For AIME datasets, prioritize boxed answers
+ if self.dataset_type in ["aime", "aime2025"]:
+ boxed_pattern = r'\\boxed{([^}]+)}'
+ boxed_matches = re.findall(boxed_pattern, pred, re.IGNORECASE)
+ if boxed_matches:
+ # Return the last boxed answer found (most likely the final answer)
+ return boxed_matches[-1].strip()
+
+ # For other datasets, search for numbers from the end of the text
+ # This prioritizes numbers that appear later in the response
+ matches = re.findall(self.pattern, pred, re.IGNORECASE)
+ if not matches:
+ return None
+
+ # Process matches from end to start
+ for match in reversed(matches):
+ if isinstance(match, tuple):
+ match = match[0] if match[0] else match[1]
+ answer = match.strip()
+ if answer:
+ return answer
+ return None
+
+ def _grade_regex(self, gold: str, pred: str) -> Tuple[bool, Optional[str]]:
+ """Grade using regex pattern matching"""
+ answer = self._extract_answer_regex(pred)
+ if answer is None:
+ return False, None
+ is_correct = answer.strip() == gold.strip()
+ return is_correct, answer
+
+ def _grade_cli(self, gold: str, pred: str) -> Tuple[bool, Optional[str]]:
+ """Grade using external CLI script"""
+ if not self.grader_script:
+ raise ValueError("CLI grader requires --grader-script")
+
+ script_path = Path(self.grader_script)
+ if not script_path.exists():
+ raise FileNotFoundError(f"Grader script not found: {self.grader_script}")
+
+ try:
+ result = subprocess.run(
+ [str(script_path), "--answer", pred, "--expected", gold],
+ capture_output=True,
+ text=True,
+ timeout=30
+ )
+ is_correct = result.returncode == 0
+ answer = pred if is_correct else None
+ return is_correct, answer
+ except subprocess.TimeoutExpired:
+ return False, None
+ except Exception as e:
+ return False, None
+
+ def _grade_llm(self, gold: str, pred: str, problem: str) -> Tuple[bool, Optional[str]]:
+ """Grade using LLM-based extraction with few-shot examples"""
+ sample_answers = SAMPLE_ANSWERS.get(self.dataset_type, [])
+ sample_examples = "\n".join([
+ f"Example {i+1}: {ans}" for i, ans in enumerate(sample_answers)
+ ])
+
+ system_prompt = f"""You are an answer extraction system. Your task is to extract the answer from the model's response.
+
+Here are some examples of extracted answers to demonstrate what you are supposed to output:
+
+{sample_examples}
+
+When extracting the answer, provide only the extracted answer itself, nothing else. If there is no clear answer that can be extracted from the response, reply with 'no answer'."""
+
+ user_prompt = f"""Extract the answer from the following response:
+
+"{pred}"
+
+Please provide only the extracted answer, nothing else. If there is no clear answer that can be extracted from the response, reply with 'no answer'."""
+
+ url = f"{self.grader_server_url}/v1/chat/completions"
+ headers = {"Content-Type": "application/json"}
+ data = {
+ "model": self.grader_model_name,
+ "messages": [
+ {"role": "system", "content": system_prompt},
+ {"role": "user", "content": user_prompt}
+ ],
+ "temperature": 0,
+ }
+ #print(json.dumps(data, indent=2))
+
+ try:
+ response = requests.post(url, headers=headers, json=data)
+ response.raise_for_status()
+ answer = response.json()["choices"][0]["message"]["content"].strip()
+ is_correct = answer.strip().lower() == gold.strip().lower()
+ return is_correct, answer
+ except Exception as e:
+ return False, None
+
+ def _truncate_response(self, response: str, max_lines: int = 6) -> str:
+ """Keep only last N lines of response"""
+ lines = response.split('\n')
+ return '\n'.join(lines[-max_lines:]) if len(lines) > max_lines else response
+
+ def grade(self, gold: str, pred: str, problem: str = "") -> Tuple[bool, Optional[str]]:
+ """Grade the response"""
+ if self.grader_type == "regex":
+ return self._grade_regex(gold, pred)
+ elif self.grader_type == "cli":
+ return self._grade_cli(gold, pred)
+ elif self.grader_type == "llm":
+ return self._grade_llm(gold, pred, problem)
+ else:
+ raise ValueError(f"Unknown grader type: {self.grader_type}")
+
+class Processor:
+ def __init__(
+ self,
+ server_configs: List[ServerConfig],
+ grader: Grader,
+ model_name: Optional[str] = None,
+ n_predict: int = -1
+ ):
+ self.server_configs = server_configs
+ self.grader = grader
+ self.model_name = model_name
+ self.n_predict = n_predict
+
+ @staticmethod
+ def _check_server(server_config: ServerConfig) -> List[str]:
+ url = f"{server_config.url}/v1/models"
+ try:
+ response = requests.get(url)
+ response.raise_for_status()
+ models = [m["id"] for m in response.json().get("data", [])]
+ return models
+ except Exception as e:
+ print(f"Error: Cannot reach server {server_config.name} ({server_config.url}): {e}", file=sys.stderr)
+ sys.exit(1)
+
+ def _make_request(
+ self, server_config: ServerConfig, eval_state: EvalState, prompt: str
+ ) -> Tuple[Dict[str, Any], int, Optional[float], Optional[float], str]:
+ url = f"{server_config.url}/v1/chat/completions"
+ headers = {"Content-Type": "application/json"}
+ data = {
+ "model": self.model_name if self.model_name else "llama",
+ "messages": [{"role": "user", "content": prompt}],
+ "n_predict": self.n_predict
+ }
+ if eval_state.sampling_config.get("temperature") is not None:
+ data["temperature"] = eval_state.sampling_config["temperature"]
+ if eval_state.sampling_config.get("top_k") is not None:
+ data["top_k"] = eval_state.sampling_config["top_k"]
+ if eval_state.sampling_config.get("top_p") is not None:
+ data["top_p"] = eval_state.sampling_config["top_p"]
+ if eval_state.sampling_config.get("min_p") is not None:
+ data["min_p"] = eval_state.sampling_config["min_p"]
+
+ response = requests.post(url, headers=headers, json=data)
+ response.raise_for_status()
+ result = response.json()
+ tokens = result.get("usage", {}).get("completion_tokens", 0)
+ timings = result.get("timings", {})
+ tps_gen = timings.get("predicted_per_second") if timings else None
+ t_gen_ms = timings.get("predicted_ms") if timings else None
+ finish_reason = result.get("choices", [{}])[0].get("finish_reason", "stop")
+ return result, tokens, tps_gen, t_gen_ms, finish_reason
+
+ def _process_single_case(
+ self, server_config: ServerConfig, eval_state: EvalState, i: int, task_id: str
+ ) -> TaskState:
+ question_text, prompt, expected = eval_state.get_case(i)
+
+ task_state = TaskState(
+ task_id=task_id,
+ prompt=prompt,
+ expected=expected,
+ question_text=question_text,
+ server_name=server_config.name
+ )
+
+ try:
+ response, tokens, tps_gen, t_gen_ms, finish_reason = self._make_request(server_config, eval_state, prompt)
+ result = response["choices"][0]["message"]["content"]
+ reasoning_content = response["choices"][0].get("message", {}).get("reasoning_content")
+ task_state.response = result
+ task_state.tokens = tokens
+ task_state.tps_gen = tps_gen
+ task_state.t_gen_ms = t_gen_ms
+ task_state.reasoning_content = reasoning_content
+
+ if finish_reason != "stop":
+ task_state.status = f"error: finish_reason={finish_reason}"
+ eval_state.add_result(
+ task_id, prompt, expected, result, None,
+ {"finish_reason": finish_reason}, False, task_state.status,
+ tokens, tps_gen, t_gen_ms, reasoning_content, server_config.name
+ )
+ eval_state.dump()
+ return task_state
+
+ result_truncated = self.grader._truncate_response(result, max_lines=10)
+ is_correct, answer = self.grader.grade(expected, result_truncated, prompt)
+
+ grader_log = {
+ "pred": result_truncated,
+ "grader_type": self.grader.grader_type
+ }
+ if self.grader.grader_type == "regex" and self.grader.pattern:
+ grader_log["pattern"] = self.grader.pattern
+
+ task_state.correct = is_correct
+ task_state.answer = answer
+ task_state.grader_log = grader_log
+ task_state.status = "ok"
+
+ eval_state.add_result(
+ task_id, prompt, expected, result, answer,
+ grader_log, is_correct, "ok",
+ tokens, tps_gen, t_gen_ms, reasoning_content, server_config.name
+ )
+
+ eval_state.dump()
+
+ except Exception as e:
+ task_state.status = f"error: {str(e)}"
+
+ return task_state
+
+ @staticmethod
+ def _worker(
+ server_config: ServerConfig,
+ processor: "Processor",
+ eval_state: EvalState,
+ task_queue: Queue,
+ results_queue: Queue,
+ ):
+ """Worker that pulls tasks from a shared queue and sends them to its server."""
+ while True:
+ task = task_queue.get()
+ if task is None: # sentinel
+ task_queue.task_done()
+ break
+ try:
+ i, task_id = task
+ result = processor._process_single_case(server_config, eval_state, i, task_id)
+ results_queue.put(result)
+ finally:
+ task_queue.task_done()
+
+ def evaluate(self, eval_state: EvalState, verbose: bool = False, resume: bool = False):
+ total_tasks = len(eval_state.tasks)
+ eval_state.total = len(eval_state.all_tasks) if eval_state.all_tasks else total_tasks
+ eval_state.processed = 0
+ start_time = time.time()
+
+ # Check servers and list models
+ server_models = [self._check_server(sc) for sc in self.server_configs]
+
+ # Print server info
+ print(f"\nProcessing {len(eval_state.tasks)} {eval_state.dataset_type.upper()} tasks ...")
+ print(f"Servers ({len(self.server_configs)}):")
+ for i, sc in enumerate(self.server_configs):
+ models_str = ", ".join(server_models[i]) if server_models[i] else "(none)"
+ print(f" {i+1}. {sc.name} — {sc.url} ({sc.threads} threads) [{models_str}]")
+ print(f"Model: {self.model_name}")
+ print(f"Grader: {self.grader.grader_type}")
+ print(f"Sampling: temp={eval_state.sampling_config.get('temperature', 'skip')}, top-k={eval_state.sampling_config.get('top_k', 'skip')}, top-p={eval_state.sampling_config.get('top_p', 'skip')}, min-p={eval_state.sampling_config.get('min_p', 'skip')}")
+ print()
+
+ # Shared task queue: all workers compete for tasks
+ task_queue: Queue = Queue()
+ for i, task_id in eval_state.tasks:
+ task_queue.put((i, task_id))
+
+ # Results queue: workers push completed TaskStates here
+ results_queue: Queue = Queue()
+
+ # Total worker threads across all servers
+ total_threads = sum(sc.threads for sc in self.server_configs)
+
+ # Add one sentinel per worker so every worker exits cleanly
+ for _ in range(total_threads):
+ task_queue.put(None)
+
+ # Launch workers: one ThreadPoolExecutor per server
+ executors: List[ThreadPoolExecutor] = []
+ worker_futures: List[Any] = []
+ for server_config in self.server_configs:
+ executor = ThreadPoolExecutor(max_workers=server_config.threads)
+ executors.append(executor)
+ for _ in range(server_config.threads):
+ future = executor.submit(
+ self._worker, server_config, self, eval_state,
+ task_queue, results_queue
+ )
+ worker_futures.append(future)
+
+ # Drain results as they complete
+ n_correct = 0
+ session_time = 0.0
+ completed_count = 0
+
+ while completed_count < total_tasks:
+ task_state = results_queue.get()
+ eval_state.processed += 1
+ completed_count += 1
+ if task_state.correct:
+ n_correct += 1
+ elapsed = time.time() - start_time
+ eval_state.total_time += elapsed
+ session_time += elapsed
+ start_time = time.time()
+ eval_state.print_progress(task_state, total_tasks, n_correct)
+
+ if verbose:
+ print(f"\nCase {eval_state.processed}: {task_state.correct}")
+ print(f" Expected: {task_state.expected}")
+ if task_state.response:
+ print(f" Response: {task_state.response}")
+ if task_state.answer:
+ print(f" Answer: {task_state.answer}")
+ print(f" Status: {task_state.status}")
+
+ # Wait for all workers to finish and shut down executors
+ for future in worker_futures:
+ future.result()
+ for executor in executors:
+ executor.shutdown(wait=True)
+
+ print(f"\nSession time: {session_time:.1f}s | Total accumulated time: {eval_state.total_time:.1f}s")
+ eval_state.print_summary()
+ eval_state.dump()
+
+def main():
+ parser = argparse.ArgumentParser(
+ description="Simplified evaluation tool for llama.cpp"
+ )
+ parser.add_argument(
+ "--server",
+ type=str,
+ default="http://localhost:8033",
+ help="Comma-separated llama-server URLs (default: http://localhost:8033)"
+ )
+ parser.add_argument(
+ "--server-name",
+ type=str,
+ default="",
+ help="Comma-separated display names for servers (default: use URLs)"
+ )
+ parser.add_argument(
+ "--dataset",
+ type=str,
+ default="aime",
+ choices=["aime", "aime2025", "gsm8k", "gpqa"],
+ help="Dataset type (default: aime)"
+ )
+ parser.add_argument(
+ "--n_cases",
+ type=int,
+ default=None,
+ help="Number of cases to evaluate (default: all)"
+ )
+ parser.add_argument(
+ "--seed",
+ type=int,
+ default=1234,
+ help="Random seed for shuffling (default: 1234)"
+ )
+ parser.add_argument(
+ "--n_predict",
+ type=int,
+ default=-1,
+ help="Max tokens to predict per prompt (default: -1, infinite)"
+ )
+ parser.add_argument(
+ "--temperature",
+ type=float,
+ default=None,
+ help="Sampling temperature (default: not passed)"
+ )
+ parser.add_argument(
+ "--top-k",
+ type=int,
+ default=None,
+ help="Top K sampling (default: not passed)"
+ )
+ parser.add_argument(
+ "--top-p",
+ type=float,
+ default=None,
+ help="Top P sampling (default: not passed)"
+ )
+ parser.add_argument(
+ "--min-p",
+ type=float,
+ default=None,
+ help="Min P sampling (default: not passed)"
+ )
+ parser.add_argument(
+ "--threads",
+ type=str,
+ default="32",
+ help="Comma-separated thread counts per server (default: 32)"
+ )
+ parser.add_argument(
+ "--model",
+ type=str,
+ default=None,
+ help="Model name to append as query parameter (e.g., gpt-oss-20b-hf)"
+ )
+ parser.add_argument(
+ "--verbose",
+ action="store_true",
+ help="Show detailed output for each case"
+ )
+ parser.add_argument(
+ "--output",
+ type=Path,
+ default=Path("llama-eval-state.json"),
+ help="Output file for eval state (default: llama-eval-state.json)"
+ )
+ parser.add_argument(
+ "--grader-type",
+ type=str,
+ default="llm",
+ choices=["regex", "cli", "llm"],
+ help="Grader type: regex, cli, or llm (default: llm)"
+ )
+ parser.add_argument(
+ "--grader-script",
+ type=str,
+ default=None,
+ help="CLI grader script path (required for --grader-type cli)"
+ )
+ parser.add_argument(
+ "--grader-server",
+ type=str,
+ default="",
+ help="Server URL for LLM grader (default: same as main server)"
+ )
+ parser.add_argument(
+ "--grader-model",
+ type=str,
+ default="",
+ help="Model name for LLM grader (default: same as main model)"
+ )
+ parser.add_argument(
+ "--resume",
+ action="store_true",
+ help="Resume from existing eval state"
+ )
+
+ args = parser.parse_args()
+
+ # Parse server URLs and thread counts
+ server_urls = [u.strip() for u in args.server.split(",") if u.strip()]
+ thread_counts = [int(t.strip()) for t in args.threads.split(",") if t.strip()]
+
+ if len(server_urls) != len(thread_counts):
+ print(f"Error: --server ({len(server_urls)} URLs) and --threads ({len(thread_counts)} values) must have the same count")
+ sys.exit(1)
+
+ # Parse server names (optional, defaults to URLs)
+ if args.server_name:
+ server_names = [n.strip() for n in args.server_name.split(",") if n.strip()]
+ if len(server_names) != len(server_urls):
+ print(f"Error: --server-name ({len(server_names)} names) and --server ({len(server_urls)} URLs) must have the same count")
+ sys.exit(1)
+ else:
+ server_names = server_urls # fallback to URLs
+
+ server_configs = [
+ ServerConfig(url=url, threads=threads, name=name)
+ for url, threads, name in zip(server_urls, thread_counts, server_names)
+ ]
+
+ if args.dataset == "gpqa" and args.grader_type != "llm":
+ print("Error: GPQA dataset requires --grader-type llm")
+ parser.print_help()
+ sys.exit(1)
+
+ if args.output.exists():
+ print(f"Loading existing eval state from {args.output}")
+ eval_state = EvalState.load(args.output)
+
+ # Verify model matches
+ if eval_state.model_name is not None and args.model != eval_state.model_name:
+ print(f"Error: Model mismatch. State has '{eval_state.model_name}', but --model is '{args.model}'")
+ sys.exit(1)
+
+ eval_state.print_all_tasks()
+ eval_state.print_existing_summary()
+
+ if eval_state.is_complete():
+ return
+
+ print()
+
+ if not args.resume:
+ print(f"Evaluation incomplete. Run with --resume to continue.")
+ return
+
+ pending_tasks = eval_state.get_pending_tasks()
+ print(f"Resuming from {len(pending_tasks)} pending tasks")
+
+ existing_cases = eval_state.task_states.get("cases", {})
+
+ eval_state.tasks = pending_tasks
+ eval_state.task_states["cases"] = existing_cases
+
+ grader_server_url = args.grader_server if args.grader_server else server_configs[0].url
+ grader_model_name = args.grader_model if args.grader_model else args.model
+ if args.grader_type == "llm" and not grader_model_name:
+ print("Error: --grader-type llm requires --grader-model or --model")
+ sys.exit(1)
+ grader = Grader(
+ grader_type=args.grader_type,
+ grader_script=args.grader_script,
+ grader_model_name=grader_model_name,
+ grader_server_url=grader_server_url,
+ dataset_type=eval_state.dataset_type
+ )
+ resume = True
+ else:
+ if args.resume:
+ print("Error: No existing eval state found to resume")
+ sys.exit(1)
+
+ grader_server_url = args.grader_server if args.grader_server else server_configs[0].url
+ grader_model_name = args.grader_model if args.grader_model else args.model
+ if args.grader_type == "llm" and not grader_model_name:
+ print("Error: --grader-type llm requires --grader-model or --model")
+ sys.exit(1)
+
+ grader = Grader(
+ grader_type=args.grader_type,
+ grader_script=args.grader_script,
+ grader_model_name=grader_model_name,
+ grader_server_url=grader_server_url,
+ dataset_type=args.dataset
+ )
+
+ if args.grader_type == "llm" and not args.grader_server:
+ print("Warning: Using same server for LLM grader (no --grader-server specified)")
+
+ sampling_config = {}
+ if args.temperature is not None:
+ sampling_config["temperature"] = args.temperature
+ if args.top_k is not None:
+ sampling_config["top_k"] = args.top_k
+ if args.top_p is not None:
+ sampling_config["top_p"] = args.top_p
+ if args.min_p is not None:
+ sampling_config["min_p"] = args.min_p
+
+ eval_state = EvalState(
+ dataset_type=args.dataset,
+ sampling_config=sampling_config,
+ output_file=args.output,
+ model_name=args.model
+ )
+ eval_state.load_dataset(seed=args.seed)
+ eval_state.setup_tasks(n_cases=args.n_cases, seed=args.seed)
+ eval_state.dump()
+ resume = False
+
+ eval_state.print_all_tasks()
+
+ processor = Processor(
+ server_configs=server_configs,
+ grader=grader,
+ model_name=args.model,
+ n_predict=args.n_predict
+ )
+
+ processor.evaluate(eval_state, verbose=args.verbose, resume=resume)
+ print(f"\nEval state dumped to {args.output}")
+
+if __name__ == "__main__":
+ main()
diff --git a/examples/llama-eval/llama-server-simulator.py b/examples/llama-eval/llama-server-simulator.py
new file mode 100755
index 0000000000..2f9cdc5450
--- /dev/null
+++ b/examples/llama-eval/llama-server-simulator.py
@@ -0,0 +1,317 @@
+#!/usr/bin/env python3
+
+import argparse
+import json
+import random
+import re
+import time
+import sys
+import os
+import threading
+from http.server import HTTPServer, BaseHTTPRequestHandler
+from typing import Dict, List, Optional
+from dataclasses import dataclass
+from pathlib import Path
+
+import datasets
+
+# Set cache directory for HuggingFace datasets
+cache_dir = Path.home() / ".cache" / "huggingface" / "datasets"
+cache_dir.mkdir(parents=True, exist_ok=True)
+os.environ["HF_DATASETS_CACHE"] = str(cache_dir)
+
+def dice(s1: str, s2: str) -> float:
+ """Calculate Dice coefficient between two strings based on bigram overlap."""
+ if not s1 and not s2:
+ return 1.0
+
+ def _bigrams(s: str):
+ return [s[i : i + 2] for i in range(len(s) - 1)]
+
+ bigrams1 = _bigrams(s1)
+ bigrams2 = _bigrams(s2)
+
+ if not bigrams1 and not bigrams2:
+ return 1.0
+
+ from collections import Counter
+
+ freq1 = Counter(bigrams1)
+ freq2 = Counter(bigrams2)
+
+ intersection = sum(min(freq1[bg], freq2[bg]) for bg in freq1)
+ dice_coeff = 2 * intersection / (len(bigrams1) + len(bigrams2))
+ return dice_coeff
+
+def debug_log(message: str):
+ """Log debug messages to both stdout and a file"""
+ print(message, file=sys.stderr)
+ with open("/tmp/simulator-debug.log", "a") as f:
+ f.write(message + "\n")
+
+simulator: Optional["Simulator"] = None
+
+@dataclass
+class EvalState:
+ id: str
+ tasks: List[str]
+ task_states: Dict[str, Dict]
+ sampling_config: Dict
+
+def normalize_number(s: str) -> Optional[int]:
+ match = re.match(r"\d+", s) # match digits from the start
+ if not match:
+ return None
+ return int(match.group(0))
+
+class AimeDataset:
+ def __init__(self, split: str = "train"):
+ self.split = split
+ self.questions: List[Dict] = []
+ self._load_dataset()
+
+ def _load_dataset(self):
+ print(f"Loading AIME dataset (split: {self.split})...")
+
+ cache_path = Path.home() / ".cache" / "huggingface" / "datasets" / "AI-MO___aimo-validation-aime" / "default" / "0.0.0"
+ if cache_path.exists():
+ print(f"Using cached dataset from {cache_path}")
+ ds = datasets.load_dataset("AI-MO/aimo-validation-aime", split=self.split, cache_dir=str(cache_path))
+ else:
+ ds = datasets.load_dataset("AI-MO/aimo-validation-aime", split=self.split)
+
+ self.questions = list(ds)
+ print(f"AIME dataset loaded: {len(self.questions)} questions")
+
+ def find_question(self, request_text: str) -> Optional[Dict]:
+ best_match = None
+ best_distance = -1
+ best_index = -1
+
+ for i, question in enumerate(self.questions):
+ question_text = question["problem"]
+ request_lower = request_text.lower()
+ question_lower = question_text.lower()
+
+ # Exact match
+ if question_lower == request_lower:
+ debug_log(f"DEBUG: Found exact match at index {i}")
+ return question
+
+ # Remove LaTeX formatting for more flexible matching
+ question_no_latex = re.sub(r'\$[^$]+\$', '', question_text)
+ if question_no_latex.lower() == request_lower:
+ debug_log(f"DEBUG: Found match (no LaTeX) at index {i}")
+ return question
+
+ # Calculate Dice coefficient for partial matches
+ # Only consider if request is at least 50% of question length
+ if len(request_lower) >= len(question_lower) * 0.5:
+ distance = dice(question_lower, request_lower)
+
+ if distance > best_distance:
+ best_distance = distance
+ best_match = question
+ best_index = i
+
+ if best_match and best_distance > 0.3: # Threshold for partial match
+ debug_log(f"DEBUG: Found best partial match at index {best_index} with distance {best_distance:.3f}")
+ return best_match
+
+ debug_log(f"DEBUG: No matching question found for: {request_text[:100]}...")
+ return None
+
+ def get_answer(self, question: Dict) -> str:
+ answer = question["answer"]
+ if isinstance(answer, str):
+ normalized = normalize_number(answer)
+ return str(normalized) if normalized is not None else answer
+ return str(answer)
+
+class Simulator:
+ def __init__(
+ self,
+ port: int = 8033,
+ host: str = "localhost",
+ success_rate: float = 0.8,
+ dataset_split: str = "train"
+ ):
+ self.port = port
+ self.host = host
+ self.success_rate = success_rate
+ self.dataset = AimeDataset(dataset_split)
+ self.eval_state = EvalState(
+ id="aime-2025",
+ tasks=["aime"],
+ task_states={},
+ sampling_config={"temperature": 0, "max_tokens": 2048}
+ )
+
+ def _generate_response(
+ self,
+ question: Dict,
+ should_be_correct: bool
+ ) -> Dict:
+ expected_answer = self.dataset.get_answer(question)
+
+ if should_be_correct:
+ response_text = expected_answer
+ else:
+ response_text = self._generate_wrong_answer(question)
+
+ return {
+ "id": f"chatcmpl-{int(time.time())}",
+ "object": "chat.completion",
+ "created": int(time.time()),
+ "model": "llama",
+ "choices": [
+ {
+ "index": 0,
+ "message": {
+ "role": "assistant",
+ "content": response_text
+ },
+ "finish_reason": "stop"
+ }
+ ],
+ "usage": {
+ "prompt_tokens": 100,
+ "completion_tokens": 50,
+ "total_tokens": 150
+ }
+ }
+
+ def _generate_wrong_answer(self, question: Dict) -> str:
+ expected_answer = self.dataset.get_answer(question)
+
+ if expected_answer.isdigit():
+ wrong_answer = str(int(expected_answer) + 1)
+ else:
+ wrong_answer = expected_answer + " (wrong)"
+
+ return wrong_answer
+
+ def _process_request(self, request_data: Dict) -> Dict:
+ messages = request_data.get("messages", [])
+ if not messages:
+ return {"error": "No messages in request"}
+
+ request_text = messages[0].get("content", "")
+ debug_log(f"DEBUG: Received request with content: {request_text[:150]}...")
+
+ question = self.dataset.find_question(request_text)
+ if not question:
+ debug_log(f"DEBUG: find_question returned None")
+ return {"error": "No matching question found"}
+
+ should_be_correct = random.random() < self.success_rate
+
+ response = self._generate_response(question, should_be_correct)
+
+ task_id = "aime"
+ self.eval_state.task_states[task_id] = {
+ "correct": should_be_correct,
+ "expected": self.dataset.get_answer(question),
+ "predicted": response["choices"][0]["message"]["content"]
+ }
+
+ return response
+
+class RequestHandler(BaseHTTPRequestHandler):
+ def do_POST(self):
+ if self.path != "/v1/chat/completions":
+ self._send_json({"error": "Not found"}, 404)
+ return
+
+ try:
+ content_length = int(self.headers.get("Content-Length", 0))
+ body = self.rfile.read(content_length)
+ request_data = json.loads(body) if body else None
+
+ if not request_data:
+ self._send_json({"error": "Invalid JSON"}, 400)
+ return
+
+ if simulator is None:
+ self._send_json({"error": "Simulator not initialized"}, 500)
+ return
+
+ response = simulator._process_request(request_data)
+ self._send_json(response, 200)
+
+ except json.JSONDecodeError:
+ self._send_json({"error": "Invalid JSON"}, 400)
+ except Exception as e:
+ print(f"Error processing request: {e}")
+ self._send_json({"error": str(e)}, 500)
+
+ def _send_json(self, data: dict, status: int = 200):
+ body = json.dumps(data).encode("utf-8")
+ self.send_response(status)
+ self.send_header("Content-Type", "application/json")
+ self.send_header("Content-Length", str(len(body)))
+ self.end_headers()
+ self.wfile.write(body)
+
+ def log_message(self, format, *args):
+ # Suppress default request logging
+ pass
+
+
+def main():
+ parser = argparse.ArgumentParser(
+ description="llama-server simulator for testing eval scripts"
+ )
+ parser.add_argument(
+ "--port",
+ type=int,
+ default=8033,
+ help="Server port (default: 8033)"
+ )
+ parser.add_argument(
+ "--host",
+ type=str,
+ default="localhost",
+ help="Server host (default: localhost)"
+ )
+ parser.add_argument(
+ "--success-rate",
+ type=float,
+ default=0.8,
+ help="Success rate 0-1 (default: 0.8)"
+ )
+ parser.add_argument(
+ "--dataset-split",
+ type=str,
+ default="train",
+ help="AIME dataset split to use (default: train)"
+ )
+
+ args = parser.parse_args()
+
+ global simulator
+ simulator = Simulator(
+ port=args.port,
+ host=args.host,
+ success_rate=args.success_rate,
+ dataset_split=args.dataset_split
+ )
+
+ server = HTTPServer((args.host, args.port), RequestHandler)
+ server_thread = threading.Thread(target=server.serve_forever, daemon=True)
+ server_thread.start()
+
+ print("\n=== llama-server-simulator ===")
+ print(f"Server running on http://{args.host}:{args.port}")
+ print(f"Success rate: {args.success_rate}")
+ print(f"AIME dataset loaded: {len(simulator.dataset.questions)} questions")
+ print("\nPress Ctrl+C to stop\n")
+
+ try:
+ server_thread.join()
+ except KeyboardInterrupt:
+ print("\nShutting down...")
+ server.shutdown()
+
+if __name__ == "__main__":
+ main()
diff --git a/examples/llama-eval/test-simulator.sh b/examples/llama-eval/test-simulator.sh
new file mode 100755
index 0000000000..f3ddf3e95d
--- /dev/null
+++ b/examples/llama-eval/test-simulator.sh
@@ -0,0 +1,86 @@
+#!/bin/bash
+
+set -e
+
+# Get the directory where this script is located
+SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
+
+echo "=== llama-server-simulator Test Script ==="
+echo ""
+
+PORT=8033
+SUCCESS_RATE=0.8
+TEST_PORT=8034
+
+echo "Starting simulator on port $PORT with success rate $SUCCESS_RATE..."
+source "$SCRIPT_DIR/venv/bin/activate"
+python3 "$SCRIPT_DIR/llama-server-simulator.py" --port $PORT --success-rate $SUCCESS_RATE > /tmp/simulator-test.log 2>&1 &
+SIMULATOR_PID=$!
+
+echo "Waiting for simulator to start..."
+sleep 5
+
+# Helper function to make a request and extract the answer
+make_request() {
+ local question="$1"
+ curl -s -X POST http://localhost:$PORT/v1/chat/completions \
+ -H "Content-Type: application/json" \
+ -d "{
+ \"model\": \"llama\",
+ \"messages\": [
+ {\"role\": \"user\", \"content\": \"$question\"}
+ ],
+ \"temperature\": 0,
+ \"max_tokens\": 2048
+ }" | python3 -c "import sys, json; data = json.load(sys.stdin); print(data.get('choices', [{}])[0].get('message', {}).get('content', data.get('error', 'No response')))"
+}
+
+# Test question (repeated in multiple tests)
+TEST_QUESTION="Quadratic polynomials P(x) and Q(x) have leading coefficients 2 and -2, respectively. The graphs of both polynomials pass through the two points (16,54) and (20,53). Find P(0) + Q(0)."
+
+echo ""
+echo "=== Test 1: Correct Answer ==="
+echo "Sending request with known question..."
+answer=$(make_request "$TEST_QUESTION")
+echo "Answer: $answer"
+echo "Expected: 116"
+echo "Correct: $([ "$answer" == "116" ] && echo "Yes" || echo "No")"
+
+echo ""
+echo "=== Test 2: Wrong Answer ==="
+echo "Sending request with known question (success rate 0.0)..."
+answer=$(make_request "$TEST_QUESTION")
+echo "Answer: $answer"
+echo "Expected: 116"
+echo "Correct: $([ "$answer" == "116" ] && echo "Yes" || echo "No")"
+
+echo ""
+echo "=== Test 3: No Matching Question ==="
+echo "Sending request with non-matching text..."
+response=$(make_request "What is the capital of France?")
+echo "Response: $response"
+echo "Expected: No matching question found"
+echo "Correct: $([ "$response" == "No matching question found" ] && echo "Yes" || echo "No")"
+
+echo ""
+echo "=== Test 4: Success Rate Verification ==="
+echo "Sending 10 requests to test success rate..."
+correct_count=0
+for i in {1..10}; do
+ answer=$(make_request "$TEST_QUESTION")
+ if [ "$answer" == "116" ]; then
+ correct_count=$((correct_count + 1))
+ fi
+ echo " Request $i: Answer = $answer"
+done
+echo "Correct answers: $correct_count/10"
+echo "Expected: ~8/10 (80% success rate)"
+echo "Success rate: $(echo "scale=1; $correct_count * 10" | bc)%"
+
+echo ""
+echo "=== Test Complete ==="
+echo "Stopping simulator..."
+kill $SIMULATOR_PID 2>/dev/null
+wait $SIMULATOR_PID 2>/dev/null || true
+
+echo "Simulator stopped."