llama.cpp/examples/llama-eval/llama-eval.py
Georgi Gerganov fde69a3607 examples : add llama-eval (#21152)
* working llama-eval mc and math suite

* multi source llama-eval

* Add readme

* add checkpointing

* examples: add llama-server simulator for testing eval scripts

Add a standalone Python script that simulates a llama-server HTTP endpoint
for testing the eval script. The simulator:

- Implements /v1/chat/completions endpoint with OpenAI-compatible format
- Loads AIME dataset from HuggingFace with local caching
- Uses Levenshtein distance for intelligent question matching
- Supports configurable success rate for correct/wrong answer generation
- Provides debug logging for troubleshooting

Also includes test scripts and documentation for testing and understanding
the simulator functionality.
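
As a quick illustration of the endpoint shape (a minimal sketch, not the simulator's
actual code; port 8033 matches the eval script's default server URL and the fields
follow the OpenAI-compatible format):

    import requests

    # Send one chat-completion request to the simulated endpoint and print the reply.
    resp = requests.post(
        "http://localhost:8033/v1/chat/completions",
        json={
            "model": "llama",
            "messages": [{"role": "user", "content": "What is 2 + 2?"}],
        },
        timeout=30,
    )
    resp.raise_for_status()
    print(resp.json()["choices"][0]["message"]["content"])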

* examples: refactor test-simulator.sh for better readability

Extract repeating question string into TEST_QUESTION variable and
create make_request() helper function to reduce code duplication.
Add proper error handling for error responses.

* docs: update llama-eval-discussion.md with session work summary

Add summary of llama-server-simulator implementation work including
features, testing results, technical decisions, and refactoring.

* examples: add simplified llama-eval-new.py for AIME evaluation

- Create new simplified evaluation script focused only on AIME
- Implement EvalState and Processor dataclasses for structured state management
- Add real-time feedback showing correct/incorrect status per case
- Abstract grading interface for external grader support
- Use structured JSON output for eval state
- Apply HuggingFace dataset caching to avoid repeated downloads
- Remove Levenshtein matching - eval script only sends requests and validates answers

* docs: remove README.md from llama-eval

* examples: implement flexible grader system for answer validation

- Add Grader class supporting regex and CLI-based grading
- Implement built-in regex patterns for AIME, GSM8K, MMLU, HellaSwag, ARC, WinoGrande
- Add CLI grader interface: python script.py --answer <pred> --expected <gold> (a minimal sketch follows this list)
- Add HF telemetry disable to avoid warnings
- Support exact match requirement for regex patterns
- Add 30-second timeout for CLI grader
- Handle both boxed and plain text formats for AIME answers
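
A minimal sketch of a grader honoring the CLI contract above (only the argument names
and the exit-code convention come from the interface: exit code 0 means "correct",
anything else means "incorrect"; the comparison itself is illustrative):

    #!/usr/bin/env python3
    import argparse
    import sys

    parser = argparse.ArgumentParser()
    parser.add_argument("--answer", required=True)    # model prediction
    parser.add_argument("--expected", required=True)  # gold answer
    args = parser.parse_args()

    # A real grader can normalize however it likes; plain exact match shown here.
    sys.exit(0 if args.answer.strip() == args.expected.strip() else 1)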

* examples: use HF_HUB_OFFLINE to avoid HF Hub warnings

* examples: remove HF_HUB_OFFLINE to allow dataset download

* examples: use cached dataset path to avoid HF Hub requests

* examples: use cached dataset path in simulator to avoid HF Hub requests

* docs: update llama-eval-discussion.md with session work summary

* examples: add threading support and model parameter to llama-eval-new.py

- Add ThreadPoolExecutor for parallel request processing controlled by --threads
- Add --model argument to specify model name in request data
- Refactor process() to use thread-safe _process_single_case() method
- Update progress tracking to work with concurrent execution

* docs: update llama-eval-discussion.md with threading and model parameter updates

- Add threading support implementation details
- Document ThreadPoolExecutor usage and thread safety
- Add model parameter implementation details
- Include testing results for both features

* examples: add task summary table to llama-eval-new.py

* eval : print progress

* eval : add prompts

* test : fix path

* sim : fix answer matching

* eval : support multiple dataset runs

* minor

* improve grader

* docs

* remove old files

* datasets : add gsm8k

* add gpqa + sampling + docs

* rename

* grader : improve example answers

* cont

* datasets : add aime2025

* grader : update prompt

* grade : improve regex + logs

* datasets : fix aime2025

* cleanup

* add AGENTS.md

* ignore errors

* resume eval

* cleanup

* fix counts

* simplify

* fix prompts

* add html

* store full response

* add tokens

* reasoning and error handling

* refactor

* track total time

* remove junk

* eval : unify "judge" terminology to "grader"

Replace all occurrences of "judge" with "grader" for consistency
across the codebase (CLI args, Grader class fields, help text).

Assisted-by: llama.cpp:local pi

* eval : add Wilson score confidence interval to results

Compute 95% CI on-the-fly from completed cases. Displayed in
terminal output, HTML report, and JSON state.
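
For reference, with k correct out of n completed cases, \hat{p} = k/n, and z = 1.96,
the reported interval is center ± margin in the standard Wilson form (matching
wilson_interval() in the script):

    \text{center} = \frac{\hat{p} + z^2/(2n)}{1 + z^2/n}, \qquad
    \text{margin} = \frac{z}{1 + z^2/n}\sqrt{\frac{\hat{p}(1-\hat{p})}{n} + \frac{z^2}{4n^2}}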

* llama-eval : add per-task generation speed from server timings

Extract predicted_per_second from the server timings response and store
it as tps_gen per task. Display in console progress, print_all_tasks,
and HTML report.

Assisted-by: llama.cpp:local pi

* llama-eval : add per-task generation time from server timings

Extract predicted_ms from the server timings response and store it as
t_gen_ms per task. Display in seconds with one decimal digit in console
progress, print_all_tasks, and HTML report.

Assisted-by: llama.cpp:local pi

* llama-eval : rename display, escaped, and count variables to use prefix convention

- _display suffix → display_ prefix (answer, tokens, tps, t_gen)
- _escaped suffix → escaped_ prefix (response, prompt, reasoning)
- _count suffix → n_ prefix (correct, incorrect, pending)

Assisted-by: llama.cpp:local pi

* llama-eval : support multiple evaluation endpoints with dynamic task distribution

- Add ServerConfig dataclass (url, threads, name)
- Accept comma-separated --server, --threads, --server-name CLI args
- Dynamic shared-queue task distribution across servers (fast servers do more work)
- One ThreadPoolExecutor per server, workers pull from shared Queue
- Track which server processed each task (server_name in results)
- Thread-safe EvalState with threading.Lock for concurrent mutations
- Server column in HTML report and console output
- Backward compatible: single server works as before

Assisted-by: llama.cpp:local pi

* llama-server-simulator : replace Flask with stdlib http.server

- Use HTTPServer + BaseHTTPRequestHandler instead of Flask
- RequestHandler handles POST /v1/chat/completions
- Server runs in daemon thread with clean Ctrl+C shutdown
- Remove flask and unused asdict imports

Assisted-by: llama.cpp:local pi
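
A minimal sketch of the stdlib pattern described above (handler name, port, and the
canned reply are illustrative, not the simulator's actual implementation):

    import json
    import threading
    from http.server import HTTPServer, BaseHTTPRequestHandler

    class RequestHandler(BaseHTTPRequestHandler):
        def do_POST(self):
            # Only the chat-completions route is handled; everything else is 404.
            if self.path != "/v1/chat/completions":
                self.send_response(404)
                self.end_headers()
                return
            length = int(self.headers.get("Content-Length", 0))
            _request = json.loads(self.rfile.read(length) or b"{}")
            body = json.dumps({
                "choices": [{
                    "message": {"role": "assistant", "content": "\\boxed{42}"},
                    "finish_reason": "stop",
                }],
            }).encode()
            self.send_response(200)
            self.send_header("Content-Type", "application/json")
            self.send_header("Content-Length", str(len(body)))
            self.end_headers()
            self.wfile.write(body)

    server = HTTPServer(("127.0.0.1", 8033), RequestHandler)
    thread = threading.Thread(target=server.serve_forever, daemon=True)
    thread.start()
    try:
        thread.join()           # serve until Ctrl+C
    except KeyboardInterrupt:
        server.shutdown()       # clean shutdown of serve_forever()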

* llama-eval : update README with PR link and quick-start examples

Assisted-by: llama.cpp:local pi

* llama-eval : track model name in eval state and verify on resume

- Store model_name in EvalState and JSON output
- Display model in HTML summary table
- Verify --model matches stored model when resuming

Assisted-by: llama.cpp:local pi

* llama-server-simulator : fix comment - Dice coefficient, not Levenshtein

Assisted-by: llama.cpp:local pi

* llama-eval : require --grader-model or --model when using --grader-type llm

Assisted-by: llama.cpp:local pi

* llama-eval : protect dump() with lock for thread safety

Assisted-by: llama.cpp:local pi

* llama-eval : compact HTML report output

- Replace verbose summary table with single inline bar
- Shorten status text: '✓'/'✗'/'–'/'!' instead of full words
- Flatten CSS: remove box-shadows, border-radius, reduce padding
- Use system-ui font, 13px table, 12px details
- Conditional reasoning section (only shown when present)
- Single toggle JS function instead of two
- Shorter column headers

Assisted-by: llama.cpp:local pi

* llama-eval : check server connectivity on startup

- Hit /v1/models for each server before evaluation
- Exit with error if any server is unreachable
- Print comma-separated model IDs per server in startup output
- Sequential checks, no retries, no timeout override

Assisted-by: llama.cpp:local pi

* llama-eval : use server1/server2 instead of gpu1/gpu2 in README

Assisted-by: llama.cpp:local pi

---------

Co-authored-by: gatbontonpc <gatbontonpc@gmail.com>
2026-05-12 15:07:00 +03:00

1417 lines · 52 KiB · Python · Executable File
#!/usr/bin/env python3
# type: ignore
import argparse
import json
import os
import re
import subprocess
import sys
import threading
import time
from abc import ABC, abstractmethod
from concurrent.futures import ThreadPoolExecutor, as_completed
from dataclasses import dataclass, asdict, field
from pathlib import Path
from queue import Queue
from typing import Dict, List, Optional, Any, Tuple
import requests
from tqdm import tqdm
import random
from math import sqrt
@dataclass
class ServerConfig:
url: str
threads: int
name: str = ""
def wilson_interval(correct: int, total: int, z: float = 1.96) -> Tuple[float, float]:
"""Wilson score confidence interval for a proportion."""
if total == 0:
return (0.0, 1.0)
p = correct / total
z2 = z * z / total
center = (p + z2 / 2) / (1 + z2)
margin = z * sqrt((p * (1 - p) + z2 / 4) / total) / (1 + z2)
return (center - margin, center + margin)
cache_dir = Path.home() / ".cache" / "huggingface" / "datasets"
cache_dir.mkdir(parents=True, exist_ok=True)
os.environ["HF_DATASETS_CACHE"] = str(cache_dir)
os.environ["HF_HUB_DISABLE_TELEMETRY"] = "1"
GRADER_PATTERNS = {
"aime": r'\boxed{(\d+)}|\b(\d+)\b',
"aime2025": r'\boxed{(\d+)}|\b(\d+)\b',
"gsm8k": r'\b(\d+)\b',
}
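# Few-shot example answers embedded in the LLM grader's system prompt to illustrate the expected extraction format.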
SAMPLE_ANSWERS = {
"aime": [
"42",
"-123",
"999"
],
"aime2025": [
"42",
"-123",
"999"
],
"gsm8k": [
"42",
"-123",
"999"
],
"gpqa": [
"A",
"D",
"C"
],
}
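# Prompt templates per dataset; literal braces (e.g. \boxed{{}}) are doubled because these strings go through str.format().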
TEMPLATE_REGISTRY = {
"aime": """Solve the following math problem step by step. Put your answer inside \\boxed{{}}.
{question}
Remember to put your answer inside \\boxed{{}}.
""",
"aime2025": """Solve the following math problem step by step. Put your answer inside \\boxed{{}}.
{question}
Remember to put your answer inside \\boxed{{}}.
""",
"gsm8k": """{question}
Please reason step by step, and put your final numeric answer within \\boxed{{}} without any extra characters.
""",
"gpqa": """Answer the following multiple choice question. The last line of your response should be in the following format: 'Answer: A/B/C/D' (e.g. 'Answer: A').
{Question}
A) {A}
B) {B}
C) {C}
D) {D}
""",
}
class BaseDataset(ABC):
@abstractmethod
def get_question(self, index: int) -> Dict:
pass
@abstractmethod
def get_question_text(self, question: Dict) -> str:
pass
@abstractmethod
def get_answer(self, question: Dict) -> str:
pass
@abstractmethod
def get_prompt(self, question: Dict) -> str:
pass
def __len__(self) -> int:
return len(self.questions)
@dataclass
class TaskState:
task_id: str
prompt: str
expected: str
question_text: str = ""
response: Optional[str] = None
answer: Optional[str] = None
grader_log: Dict[str, Any] = field(default_factory=dict)
correct: bool = False
status: str = "pending"
tokens: Optional[int] = None
tps_gen: Optional[float] = None
t_gen_ms: Optional[float] = None
reasoning_content: Optional[str] = None
server_name: Optional[str] = None
class EvalState:
def __init__(
self,
dataset_type: str,
sampling_config: Dict[str, Any],
output_file: Path = Path("llama-eval-state.json"),
model_name: Optional[str] = None
):
self.dataset_type = dataset_type
self.sampling_config = sampling_config
self.output_file = output_file
self.model_name = model_name
self.dataset: Optional[BaseDataset] = None
self.tasks: List[Tuple[int, str]] = []
self.all_tasks: List[Tuple[int, str]] = []
self.task_states: Dict[str, Any] = {}
self.total = 0
self.correct = 0
self.processed = 0
self.total_time: float = 0.0
self._lock = threading.Lock()
def load_dataset(self, seed: int = 1234):
if self.dataset_type == "aime":
self.dataset = AimeDataset()
elif self.dataset_type == "aime2025":
self.dataset = Aime2025Dataset()
elif self.dataset_type == "gsm8k":
self.dataset = Gsm8kDataset()
elif self.dataset_type == "gpqa":
self.dataset = GpqaDataset(variant="diamond", seed=seed)
else:
raise ValueError(f"Unknown dataset type: {self.dataset_type}")
def setup_tasks(self, n_cases: Optional[int] = None, seed: int = 1234):
if self.dataset is None:
raise ValueError("Dataset not loaded. Call load_dataset() first.")
if n_cases is None:
n_cases = len(self.dataset)
dataset_size = len(self.dataset)
rng = random.Random(seed)
self.tasks = []
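# If n_cases exceeds the dataset size, repeat the dataset in shuffled chunks so every question appears once per pass.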
for chunk_idx in range((n_cases + dataset_size - 1) // dataset_size):
chunk_size = min(dataset_size, n_cases - chunk_idx * dataset_size)
indices = list(range(dataset_size))
rng.shuffle(indices)
chunk_indices = indices[:chunk_size]
for i in chunk_indices:
task_id = f"{self.dataset_type}_{chunk_idx:03d}_{i:03d}"
self.tasks.append((i, task_id))
self.all_tasks = list(self.tasks)
def get_case(self, index: int) -> Tuple[str, str, str]:
if self.dataset is None:
raise ValueError("Dataset not loaded.")
question = self.dataset.get_question(index)
question_text = self.dataset.get_question_text(question)
prompt = self.dataset.get_prompt(question)
expected = self.dataset.get_answer(question)
return question_text, prompt, expected
def add_result(
self,
task_id: str,
prompt: str,
expected: str,
response: Optional[str],
answer: Optional[str],
grader_log: Dict[str, Any],
correct: bool,
status: str,
tokens: Optional[int] = None,
tps_gen: Optional[float] = None,
t_gen_ms: Optional[float] = None,
reasoning_content: Optional[str] = None,
server_name: Optional[str] = None
):
with self._lock:
if "cases" not in self.task_states:
self.task_states["cases"] = {}
self.task_states["cases"][task_id] = {
"task_id": task_id,
"prompt": prompt,
"expected": expected,
"response": response,
"answer": answer,
"grader_log": grader_log,
"correct": correct,
"status": status,
"tokens": tokens,
"tps_gen": tps_gen,
"t_gen_ms": t_gen_ms,
"reasoning_content": reasoning_content,
"server_name": server_name
}
self.correct = sum(1 for c in self.task_states.get("cases", {}).values() if c.get("correct", False))
def print_progress(self, task_state: TaskState, total_tasks: int, n_correct: int = 0):
display_answer = task_state.answer if task_state.answer else "N/A"
display_tokens = str(task_state.tokens) if task_state.tokens is not None else "N/A"
display_tps = f"{task_state.tps_gen:.1f}" if task_state.tps_gen is not None else "N/A"
display_t_gen = f"{task_state.t_gen_ms/1000:.1f}" if task_state.t_gen_ms is not None else "N/A"
display_server = task_state.server_name if task_state.server_name else "N/A"
success_ratio = n_correct / self.processed if self.processed > 0 else 0.0
first_line = task_state.question_text.split('\n')[0]
truncated_question = first_line[:43]
if len(first_line) > 43:
truncated_question += "..."
else:
truncated_question = truncated_question.ljust(43) + "..."
print(f"{self.processed:3}/{total_tasks:3} {task_state.task_id:<20} {self.dataset_type.upper()} {truncated_question:<40} {task_state.expected:<10} {display_answer:<10} {display_tokens:<6} {display_tps:<6} {display_t_gen:<8} {'' if task_state.correct else ''} [{n_correct:3}/{self.processed:3}, {success_ratio:.3f}] {display_server}")
def print_summary(self):
if self.total == 0:
print(f"\n{'='*60}")
print(f"Results: 0/0 correct (0.0%)")
print(f"{'='*60}")
else:
ci_lower, ci_upper = self.accuracy_ci()
print(f"\n{'='*60}")
print(f"Results: {self.correct}/{self.total} correct ({self.correct/self.total*100:.1f}%) [{ci_lower*100:.1f}%, {ci_upper*100:.1f}%]")
print(f"{'='*60}")
def dump(self):
with self._lock:
tasks_to_save = self.all_tasks if self.all_tasks else self.tasks
all_cases = {}
for i, task_id in tasks_to_save:
question_text, prompt, expected = self.get_case(i)
if task_id in self.task_states.get("cases", {}):
all_cases[task_id] = self.task_states["cases"][task_id]
else:
all_cases[task_id] = {
"task_id": task_id,
"prompt": prompt,
"expected": expected,
"question_text": question_text,
"response": None,
"answer": None,
"grader_log": {},
"correct": False,
"status": "pending",
"tokens": None,
"tps_gen": None,
"t_gen_ms": None,
"reasoning_content": None,
"server_name": None
}
ci_lower, ci_upper = self.accuracy_ci()
data = {
"id": self.dataset_type,
"model_name": self.model_name,
"tasks": [tid for _, tid in tasks_to_save],
"task_states": {
"total": self.total,
"correct": self.correct,
"total_time": self.total_time,
"ci_lower": ci_lower,
"ci_upper": ci_upper,
"cases": all_cases,
},
"sampling_config": self.sampling_config
}
with open(self.output_file, "w") as f:
json.dump(data, f, indent=2)
self.dump_html(tasks_to_save, all_cases)
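# dump_html() renders a single-page report: an inline summary bar plus one expandable row per task with prompt, response, reasoning, and grader log.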
def dump_html(self, tasks_to_save: List[Tuple[int, str]], all_cases: Dict[str, Any]):
html_file = Path(str(self.output_file) + ".html")
cases = all_cases
completed = {tid: c for tid, c in cases.items() if c.get("status") == "ok"}
n_correct = sum(1 for c in completed.values() if c.get("correct", False))
n_incorrect = len(completed) - n_correct
n_pending = len(tasks_to_save) - len(completed)
accuracy = n_correct / len(completed) * 100 if completed else 0.0
ci_lower, ci_upper = wilson_interval(n_correct, len(completed)) if completed else (0.0, 1.0)
sampling_parts = []
for k, v in self.sampling_config.items():
if v is not None:
sampling_parts.append(f"{k}={v}")
sampling_str = ", ".join(sampling_parts) if sampling_parts else "default"
rows = []
for i, task_id in tasks_to_save:
case = cases.get(task_id, {})
status = case.get("status", "pending")
expected = case.get("expected", "")
answer = case.get("answer", "") if status == "ok" else ""
is_correct = case.get("correct", False) if status == "ok" else False
response = case.get("response", "") or ""
prompt = case.get("prompt", "") or ""
grader_log = case.get("grader_log", {})
if status == "ok":
status_class = "correct" if is_correct else "incorrect"
status_text = "" if is_correct else ""
elif status == "pending":
status_class = "pending"
status_text = ""
else:
status_class = "error"
status_text = "!"
tokens = case.get("tokens")
tokens_str = str(tokens) if tokens is not None else ""
tps_gen = case.get("tps_gen")
tps_str = f"{tps_gen:.1f}" if tps_gen is not None else ""
t_gen_ms = case.get("t_gen_ms")
t_gen_str = f"{t_gen_ms/1000:.1f}" if t_gen_ms is not None else ""
reasoning_content = case.get("reasoning_content", "") or ""
server_name = case.get("server_name", "") or ""
escaped_response = self._escape_html(response)
escaped_prompt = self._escape_html(prompt)
escaped_reasoning = self._escape_html(reasoning_content)
grader_log_str = self._escape_html(json.dumps(grader_log, indent=2))
escaped_server = self._escape_html(server_name)
rows.append(f"""<tr class="task-row" onclick="toggleDetails('{task_id}')">
<td>{task_id}</td>
<td class="{status_class}">{status_text}</td>
<td>{self._escape_html(expected)}</td>
<td>{self._escape_html(answer)}</td>
<td>{tokens_str}</td>
<td>{tps_str}</td>
<td>{t_gen_str}</td>
<td>{escaped_server}</td>
</tr>
<tr id="details-{task_id}" class="details-row">
<td colspan="8">
<div class="details-content">
<b>Prompt</b><pre>{escaped_prompt}</pre>
<b>Response</b><pre>{escaped_response}</pre>
{f'<b>Reasoning</b><pre>{escaped_reasoning}</pre>' if escaped_reasoning else ''}
<b>Grader</b><pre>{grader_log_str}</pre>
</div>
</td>
</tr>""")
rows_html = "\n".join(rows)
html_content = f"""<!DOCTYPE html>
<html>
<head>
<meta charset="UTF-8">
<title>{self.dataset_type.upper()} Eval</title>
<style>
body {{ font-family: system-ui, sans-serif; margin: 0; padding: 16px; background: #fff; color: #222; }}
.bar {{ padding: 8px 0; font-size: 14px; color: #555; }}
.bar span {{ margin-right: 20px; }}
.bar b {{ color: #222; }}
table {{ width: 100%; border-collapse: collapse; font-size: 13px; }}
th {{ text-align: left; padding: 6px 8px; border-bottom: 2px solid #ccc; font-weight: 600; }}
td {{ padding: 4px 8px; border-bottom: 1px solid #eee; vertical-align: top; }}
.task-row {{ cursor: pointer; }}
.task-row:hover {{ background: #f5f5f5; }}
.correct {{ color: #1a7f37; }}
.incorrect {{ color: #cf222e; }}
.pending {{ color: #888; }}
.error {{ color: #9a6700; }}
.details-row {{ display: none; }}
.details-row.open {{ display: table-row; }}
.details-content {{ padding: 8px 16px; background: #f6f8fa; font-size: 12px; }}
.details-content b {{ color: #555; }}
.details-content pre {{ background: #fff; border: 1px solid #e1e4e8; padding: 8px; overflow-x: auto; white-space: pre-wrap; word-wrap: break-word; margin: 4px 0 8px; }}
</style>
</head>
<body>
<div class="bar">
<span><b>{self.dataset_type.upper()}</b></span>
<span>Model: {self.model_name or 'N/A'}</span>
<span>Accuracy: <b>{accuracy:.1f}%</b> [{ci_lower*100:.1f}%, {ci_upper*100:.1f}%]</span>
<span>Correct: <span class="correct">{n_correct}</span> / {len(completed)}</span>
<span>Pending: {n_pending}</span>
<span>Time: {self.total_time:.1f}s</span>
<span>Sampling: {sampling_str}</span>
</div>
<table>
<thead>
<tr>
<th>ID</th>
<th></th>
<th>Gold</th>
<th>Answer</th>
<th>Tokens</th>
<th>T/s</th>
<th>Gen s</th>
<th>Server</th>
</tr>
</thead>
<tbody>
{rows_html}
</tbody>
</table>
<script>
function toggleDetails(id) {{ document.getElementById('details-'+id).classList.toggle('open'); }}
</script>
</body>
</html>"""
with open(html_file, "w") as f:
f.write(html_content)
def _escape_html(self, s: str) -> str:
return (s.replace("&", "&amp;")
.replace("<", "&lt;")
.replace(">", "&gt;")
.replace('"', "&quot;")
.replace("'", "&#39;"))
@classmethod
def load(cls, path: Path) -> "EvalState":
with open(path, "r") as f:
data = json.load(f)
eval_state = cls(
dataset_type=data["id"],
sampling_config=data["sampling_config"],
output_file=path,
model_name=data.get("model_name")
)
eval_state.load_dataset()
eval_state.tasks = []
eval_state.all_tasks = []
for task_id in data.get("tasks", []):
parts = task_id.rsplit("_", 2)
if len(parts) >= 3:
idx = int(parts[-1])
else:
idx = 0
eval_state.tasks.append((idx, task_id))
eval_state.all_tasks.append((idx, task_id))
eval_state.task_states = data.get("task_states", {})
cases = eval_state.task_states.get("cases", {})
eval_state.total = eval_state.task_states.get("total", 0)
eval_state.correct = eval_state.task_states.get("correct", 0)
eval_state.total_time = eval_state.task_states.get("total_time", 0.0)
if eval_state.total == 0:
eval_state.total = len(cases)
eval_state.correct = sum(1 for c in cases.values() if c.get("correct", False))
return eval_state
def is_complete(self) -> bool:
if not self.all_tasks:
return False
cases = self.task_states.get("cases", {})
completed = {tid for tid in self.task_states.get("cases", {}).keys() if cases.get(tid, {}).get("status") == "ok"}
return len(completed) == len(self.all_tasks)
def get_pending_tasks(self) -> List[Tuple[int, str]]:
cases = self.task_states.get("cases", {})
pending = []
for i, task_id in self.all_tasks:
status = cases.get(task_id, {}).get("status", "pending")
if status != "ok":
pending.append((i, task_id))
return pending
def print_all_tasks(self):
cases = self.task_states.get("cases", {})
tasks_to_show = self.all_tasks if self.all_tasks else self.tasks
print()
print("Tasks:")
print(" Task ID Dataset Prompt (first 40 chars) Expected Answer Tokens T/s Gen s Status")
for i, task_id in tasks_to_show:
question, prompt, expected = self.get_case(i)
case = cases.get(task_id, {})
status = case.get("status", "pending")
answer = case.get("answer", "N/A") if status == "ok" else "N/A"
tokens = case.get("tokens")
tokens_str = str(tokens) if tokens is not None else "N/A"
tps_gen = case.get("tps_gen")
tps_str = f"{tps_gen:.1f}" if tps_gen is not None else "N/A"
t_gen_ms = case.get("t_gen_ms")
t_gen_str = f"{t_gen_ms/1000:.1f}" if t_gen_ms is not None else "N/A"
server_name = case.get("server_name", "") or ""
is_correct = case.get("correct", False) if status == "ok" else False
symbol = "" if is_correct else ("" if status == "ok" else "")
first_line = question.split('\n')[0]
question_trunc = first_line[:43]
if len(first_line) > 43:
question_trunc += "..."
else:
question_trunc = question_trunc.ljust(43) + "..."
print(f" {task_id:<20} {self.dataset_type.upper()} {question_trunc:<40} {expected:<10} {answer:<10} {tokens_str:<6} {tps_str:<6} {t_gen_str:<8} {symbol}{status} {server_name}")
print()
def print_existing_summary(self):
cases = self.task_states.get("cases", {})
completed_cases = {tid: c for tid, c in cases.items() if c.get("status") == "ok"}
correct = sum(1 for c in completed_cases.values() if c.get("correct", False))
total = len(completed_cases)
if total == 0:
print(f"{'='*60}")
print(f"Results: 0/0 correct (0.0%)")
print(f"{'='*60}")
else:
ci_lower, ci_upper = self.accuracy_ci()
print(f"{'='*60}")
print(f"Results: {correct}/{total} correct ({correct/total*100:.1f}%) [{ci_lower*100:.1f}%, {ci_upper*100:.1f}%]")
print(f"{'='*60}")
def accuracy_ci(self) -> Tuple[float, float]:
"""Compute Wilson score confidence interval from completed cases."""
cases = self.task_states.get("cases", {})
completed = {tid: c for tid, c in cases.items() if c.get("status") == "ok"}
correct = sum(1 for c in completed.values() if c.get("correct", False))
total = len(completed)
return wilson_interval(correct, total)
def normalize_number(s: str) -> Optional[int]:
match = re.match(r"\d+", s) # match digits from the start
if not match:
return None
return int(match.group(0))
class AimeDataset(BaseDataset):
def __init__(self, split: str = "train"):
self.split = split
self.questions: List[Dict] = []
self._load_dataset()
def _load_dataset(self):
print(f"Loading AIME dataset (split: {self.split})...")
from datasets import load_dataset
cache_path = cache_dir / "AI-MO___aimo-validation-aime" / "default" / "0.0.0"
if cache_path.exists():
print(f"Using cached dataset from {cache_path}")
ds = load_dataset("AI-MO/aimo-validation-aime", split=self.split, cache_dir=str(cache_path))
else:
ds = load_dataset("AI-MO/aimo-validation-aime", split=self.split)
self.questions = []
for row in ds:
question = dict(row)
question["dataset_type"] = "aime"
self.questions.append(question)
print(f"AIME dataset loaded: {len(self.questions)} questions")
def get_question(self, index: int) -> Dict:
"""Get question by index"""
return self.questions[index]
def get_question_text(self, question: Dict) -> str:
"""Get question string"""
return question["problem"] if "problem" in question else question["question"]
def get_answer(self, question: Dict) -> str:
answer = question["answer"]
if isinstance(answer, str):
normalized = normalize_number(answer)
return str(normalized) if normalized is not None else answer
return str(answer)
def get_prompt(self, question: Dict) -> str:
"""Get formatted prompt for the question"""
return TEMPLATE_REGISTRY[question["dataset_type"]].format(
question=self.get_question_text(question),
)
class Aime2025Dataset(BaseDataset):
def __init__(self):
self.questions: List[Dict] = []
self._load_dataset()
def _load_dataset(self):
print(f"Loading AIME2025 dataset...")
from datasets import load_dataset
config_name = "AIME2025-I"
cache_path = cache_dir / "opencompass___AIME2025" / "default" / "0.0.0"
if cache_path.exists():
print(f"Using cached dataset from {cache_path}")
ds = load_dataset("opencompass/AIME2025", config_name, split="test", cache_dir=str(cache_path))
else:
ds = load_dataset("opencompass/AIME2025", config_name, split="test")
self.questions = []
for row in ds:
question = dict(row)
question["dataset_type"] = "aime2025"
self.questions.append(question)
print(f"AIME2025 dataset loaded: {len(self.questions)} questions")
print(f"Loading AIME2025 dataset (part 2)...")
config_name_2 = "AIME2025-II"
cache_path_2 = cache_dir / "opencompass___AIME2025" / "default" / "0.0.0"
if cache_path_2.exists():
print(f"Using cached dataset from {cache_path_2}")
ds_2 = load_dataset("opencompass/AIME2025", config_name_2, split="test", cache_dir=str(cache_path_2))
else:
ds_2 = load_dataset("opencompass/AIME2025", config_name_2, split="test")
for row in ds_2:
question = dict(row)
question["dataset_type"] = "aime2025"
self.questions.append(question)
print(f"AIME2025 dataset loaded: {len(self.questions)} questions (total)")
def get_question(self, index: int) -> Dict:
"""Get question by index"""
return self.questions[index]
def get_question_text(self, question: Dict) -> str:
"""Get question string"""
return question["question"]
def get_answer(self, question: Dict) -> str:
answer = question["answer"]
if isinstance(answer, str):
normalized = normalize_number(answer)
return str(normalized) if normalized is not None else answer
return str(answer)
def get_prompt(self, question: Dict) -> str:
"""Get formatted prompt for the question"""
return TEMPLATE_REGISTRY["aime2025"].format(
question=self.get_question_text(question),
)
class Gsm8kDataset(BaseDataset):
def __init__(self, split: str = "test"):
self.split = split
self.questions: List[Dict] = []
self._load_dataset()
def _load_dataset(self):
print(f"Loading GSM8K dataset (split: {self.split})...")
from datasets import load_dataset
cache_path = cache_dir / "openai___gsm8k" / "default" / "0.0.0"
if cache_path.exists():
print(f"Using cached dataset from {cache_path}")
ds = load_dataset("openai/gsm8k", "main", split=self.split, cache_dir=str(cache_path))
else:
ds = load_dataset("openai/gsm8k", "main", split=self.split)
self.questions = []
for row in ds:
question = dict(row)
question["dataset_type"] = "gsm8k"
# Extract numeric answer from the answer field (already has #### prefix)
gold = question["answer"]
# Split by #### and take the last part
parts = gold.split("####")
if len(parts) > 1:
gold = parts[-1].strip()
# Extract the first number from the remaining text
normalized = normalize_number(gold)
question["gold"] = str(normalized) if normalized is not None else gold
self.questions.append(question)
print(f"GSM8K dataset loaded: {len(self.questions)} questions")
def get_question(self, index: int) -> Dict:
"""Get question by index"""
return self.questions[index]
def get_question_text(self, question: Dict) -> str:
"""Get question string"""
return question["problem"] if "problem" in question else question["question"]
def get_answer(self, question: Dict) -> str:
# GSM8K has pre-extracted gold field, AIME uses answer field
if "gold" in question:
return question["gold"]
answer = question["answer"]
if isinstance(answer, str):
normalized = normalize_number(answer)
return str(normalized) if normalized is not None else answer
return str(answer)
def get_prompt(self, question: Dict) -> str:
"""Get formatted prompt for the question"""
return TEMPLATE_REGISTRY[question["dataset_type"]].format(
question=self.get_question_text(question),
)
class GpqaDataset(BaseDataset):
def __init__(self, variant: str = "diamond", seed: int = 1234):
self.variant = variant
self.seed = seed
self.questions: List[Dict] = []
self._load_dataset()
def _load_dataset(self):
print(f"Loading GPQA dataset (variant: {self.variant})...")
import pandas as pd
url = f"https://openaipublic.blob.core.windows.net/simple-evals/gpqa_{self.variant}.csv"
df = pd.read_csv(url)
rng = random.Random(self.seed)
self.questions = []
for _, row in df.iterrows():
question = row.to_dict()
question["dataset_type"] = "gpqa"
# Shuffle the answer options
correct_answer = question["Correct Answer"]
incorrect_answers = [
question["Incorrect Answer 1"],
question["Incorrect Answer 2"],
question["Incorrect Answer 3"]
]
# Create list of (answer, is_correct) tuples
options = [(ans, ans == correct_answer) for ans in incorrect_answers]
options.append((correct_answer, True))
# Shuffle the options
rng.shuffle(options)
# Extract shuffled answers and determine correct letter
shuffled_answers = [ans for ans, _ in options]
correct_letter = chr(ord('A') + options.index((correct_answer, True)))
# Store shuffled answers and correct letter
question["shuffled_answers"] = shuffled_answers
question["correct_letter"] = correct_letter
self.questions.append(question)
print(f"GPQA dataset loaded: {len(self.questions)} questions")
def get_question(self, index: int) -> Dict:
"""Get question by index"""
return self.questions[index]
def get_question_text(self, question: Dict) -> str:
"""Get question string"""
return question["Question"]
def get_answer(self, question: Dict) -> str:
# GPQA returns the correct letter (A, B, C, or D)
return question["correct_letter"]
def get_prompt(self, question: Dict) -> str:
"""Get formatted prompt for the question"""
return TEMPLATE_REGISTRY["gpqa"].format(
Question=self.get_question_text(question),
A=question["shuffled_answers"][0],
B=question["shuffled_answers"][1],
C=question["shuffled_answers"][2],
D=question["shuffled_answers"][3]
)
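# Grader checks a model response against the gold answer using one of three strategies: regex extraction, an external CLI script (exit code 0 = correct), or LLM-based answer extraction.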
class Grader:
def __init__(
self,
grader_type: str = "llm",
grader_script: Optional[str] = None,
grader_model_name: Optional[str] = None,
grader_server_url: str = "",
dataset_type: str = "aime"
):
self.grader_type = grader_type
self.grader_script = grader_script
self.grader_model_name = grader_model_name
self.grader_server_url = grader_server_url
self.dataset_type = dataset_type
self.pattern = self._get_pattern()
def _get_pattern(self) -> Optional[str]:
if self.grader_type == "regex":
return GRADER_PATTERNS.get(self.dataset_type) # Use dataset_type as key
return None
def _extract_answer_regex(self, pred: str) -> Optional[str]:
"""Extract answer using regex pattern"""
if not self.pattern:
return None
# For AIME datasets, prioritize boxed answers
if self.dataset_type in ["aime", "aime2025"]:
boxed_pattern = r'\\boxed{([^}]+)}'
boxed_matches = re.findall(boxed_pattern, pred, re.IGNORECASE)
if boxed_matches:
# Return the last boxed answer found (most likely the final answer)
return boxed_matches[-1].strip()
# For other datasets, search for numbers from the end of the text
# This prioritizes numbers that appear later in the response
matches = re.findall(self.pattern, pred, re.IGNORECASE)
if not matches:
return None
# Process matches from end to start
for match in reversed(matches):
if isinstance(match, tuple):
match = match[0] if match[0] else match[1]
answer = match.strip()
if answer:
return answer
return None
def _grade_regex(self, gold: str, pred: str) -> Tuple[bool, Optional[str]]:
"""Grade using regex pattern matching"""
answer = self._extract_answer_regex(pred)
if answer is None:
return False, None
is_correct = answer.strip() == gold.strip()
return is_correct, answer
def _grade_cli(self, gold: str, pred: str) -> Tuple[bool, Optional[str]]:
"""Grade using external CLI script"""
if not self.grader_script:
raise ValueError("CLI grader requires --grader-script")
script_path = Path(self.grader_script)
if not script_path.exists():
raise FileNotFoundError(f"Grader script not found: {self.grader_script}")
try:
result = subprocess.run(
[str(script_path), "--answer", pred, "--expected", gold],
capture_output=True,
text=True,
timeout=30
)
is_correct = result.returncode == 0
answer = pred if is_correct else None
return is_correct, answer
except subprocess.TimeoutExpired:
return False, None
except Exception as e:
return False, None
def _grade_llm(self, gold: str, pred: str, problem: str) -> Tuple[bool, Optional[str]]:
"""Grade using LLM-based extraction with few-shot examples"""
sample_answers = SAMPLE_ANSWERS.get(self.dataset_type, [])
sample_examples = "\n".join([
f"Example {i+1}: {ans}" for i, ans in enumerate(sample_answers)
])
system_prompt = f"""You are an answer extraction system. Your task is to extract the answer from the model's response.
Here are some examples of extracted answers to demonstrate what you are supposed to output:
{sample_examples}
When extracting the answer, provide only the extracted answer itself, nothing else. If there is no clear answer that can be extracted from the response, reply with 'no answer'."""
user_prompt = f"""Extract the answer from the following response:
"{pred}"
Please provide only the extracted answer, nothing else. If there is no clear answer that can be extracted from the response, reply with 'no answer'."""
url = f"{self.grader_server_url}/v1/chat/completions"
headers = {"Content-Type": "application/json"}
data = {
"model": self.grader_model_name,
"messages": [
{"role": "system", "content": system_prompt},
{"role": "user", "content": user_prompt}
],
"temperature": 0,
}
#print(json.dumps(data, indent=2))
try:
response = requests.post(url, headers=headers, json=data)
response.raise_for_status()
answer = response.json()["choices"][0]["message"]["content"].strip()
is_correct = answer.strip().lower() == gold.strip().lower()
return is_correct, answer
except Exception as e:
return False, None
def _truncate_response(self, response: str, max_lines: int = 6) -> str:
"""Keep only last N lines of response"""
lines = response.split('\n')
return '\n'.join(lines[-max_lines:]) if len(lines) > max_lines else response
def grade(self, gold: str, pred: str, problem: str = "") -> Tuple[bool, Optional[str]]:
"""Grade the response"""
if self.grader_type == "regex":
return self._grade_regex(gold, pred)
elif self.grader_type == "cli":
return self._grade_cli(gold, pred)
elif self.grader_type == "llm":
return self._grade_llm(gold, pred, problem)
else:
raise ValueError(f"Unknown grader type: {self.grader_type}")
class Processor:
def __init__(
self,
server_configs: List[ServerConfig],
grader: Grader,
model_name: Optional[str] = None,
n_predict: int = -1
):
self.server_configs = server_configs
self.grader = grader
self.model_name = model_name
self.n_predict = n_predict
@staticmethod
def _check_server(server_config: ServerConfig) -> List[str]:
url = f"{server_config.url}/v1/models"
try:
response = requests.get(url)
response.raise_for_status()
models = [m["id"] for m in response.json().get("data", [])]
return models
except Exception as e:
print(f"Error: Cannot reach server {server_config.name} ({server_config.url}): {e}", file=sys.stderr)
sys.exit(1)
def _make_request(
self, server_config: ServerConfig, eval_state: EvalState, prompt: str
) -> Tuple[Dict[str, Any], int, Optional[float], Optional[float], str]:
url = f"{server_config.url}/v1/chat/completions"
headers = {"Content-Type": "application/json"}
data = {
"model": self.model_name if self.model_name else "llama",
"messages": [{"role": "user", "content": prompt}],
"n_predict": self.n_predict
}
if eval_state.sampling_config.get("temperature") is not None:
data["temperature"] = eval_state.sampling_config["temperature"]
if eval_state.sampling_config.get("top_k") is not None:
data["top_k"] = eval_state.sampling_config["top_k"]
if eval_state.sampling_config.get("top_p") is not None:
data["top_p"] = eval_state.sampling_config["top_p"]
if eval_state.sampling_config.get("min_p") is not None:
data["min_p"] = eval_state.sampling_config["min_p"]
response = requests.post(url, headers=headers, json=data)
response.raise_for_status()
result = response.json()
tokens = result.get("usage", {}).get("completion_tokens", 0)
timings = result.get("timings", {})
tps_gen = timings.get("predicted_per_second") if timings else None
t_gen_ms = timings.get("predicted_ms") if timings else None
finish_reason = result.get("choices", [{}])[0].get("finish_reason", "stop")
return result, tokens, tps_gen, t_gen_ms, finish_reason
def _process_single_case(
self, server_config: ServerConfig, eval_state: EvalState, i: int, task_id: str
) -> TaskState:
question_text, prompt, expected = eval_state.get_case(i)
task_state = TaskState(
task_id=task_id,
prompt=prompt,
expected=expected,
question_text=question_text,
server_name=server_config.name
)
try:
response, tokens, tps_gen, t_gen_ms, finish_reason = self._make_request(server_config, eval_state, prompt)
result = response["choices"][0]["message"]["content"]
reasoning_content = response["choices"][0].get("message", {}).get("reasoning_content")
task_state.response = result
task_state.tokens = tokens
task_state.tps_gen = tps_gen
task_state.t_gen_ms = t_gen_ms
task_state.reasoning_content = reasoning_content
if finish_reason != "stop":
task_state.status = f"error: finish_reason={finish_reason}"
eval_state.add_result(
task_id, prompt, expected, result, None,
{"finish_reason": finish_reason}, False, task_state.status,
tokens, tps_gen, t_gen_ms, reasoning_content, server_config.name
)
eval_state.dump()
return task_state
result_truncated = self.grader._truncate_response(result, max_lines=10)
is_correct, answer = self.grader.grade(expected, result_truncated, prompt)
grader_log = {
"pred": result_truncated,
"grader_type": self.grader.grader_type
}
if self.grader.grader_type == "regex" and self.grader.pattern:
grader_log["pattern"] = self.grader.pattern
task_state.correct = is_correct
task_state.answer = answer
task_state.grader_log = grader_log
task_state.status = "ok"
eval_state.add_result(
task_id, prompt, expected, result, answer,
grader_log, is_correct, "ok",
tokens, tps_gen, t_gen_ms, reasoning_content, server_config.name
)
eval_state.dump()
except Exception as e:
task_state.status = f"error: {str(e)}"
return task_state
@staticmethod
def _worker(
server_config: ServerConfig,
processor: "Processor",
eval_state: EvalState,
task_queue: Queue,
results_queue: Queue,
):
"""Worker that pulls tasks from a shared queue and sends them to its server."""
while True:
task = task_queue.get()
if task is None: # sentinel
task_queue.task_done()
break
try:
i, task_id = task
result = processor._process_single_case(server_config, eval_state, i, task_id)
results_queue.put(result)
finally:
task_queue.task_done()
def evaluate(self, eval_state: EvalState, verbose: bool = False, resume: bool = False):
total_tasks = len(eval_state.tasks)
eval_state.total = len(eval_state.all_tasks) if eval_state.all_tasks else total_tasks
eval_state.processed = 0
start_time = time.time()
# Check servers and list models
server_models = [self._check_server(sc) for sc in self.server_configs]
# Print server info
print(f"\nProcessing {len(eval_state.tasks)} {eval_state.dataset_type.upper()} tasks ...")
print(f"Servers ({len(self.server_configs)}):")
for i, sc in enumerate(self.server_configs):
models_str = ", ".join(server_models[i]) if server_models[i] else "(none)"
print(f" {i+1}. {sc.name}{sc.url} ({sc.threads} threads) [{models_str}]")
print(f"Model: {self.model_name}")
print(f"Grader: {self.grader.grader_type}")
print(f"Sampling: temp={eval_state.sampling_config.get('temperature', 'skip')}, top-k={eval_state.sampling_config.get('top_k', 'skip')}, top-p={eval_state.sampling_config.get('top_p', 'skip')}, min-p={eval_state.sampling_config.get('min_p', 'skip')}")
print()
# Shared task queue: all workers compete for tasks
task_queue: Queue = Queue()
for i, task_id in eval_state.tasks:
task_queue.put((i, task_id))
# Results queue: workers push completed TaskStates here
results_queue: Queue = Queue()
# Total worker threads across all servers
total_threads = sum(sc.threads for sc in self.server_configs)
# Add one sentinel per worker so every worker exits cleanly
for _ in range(total_threads):
task_queue.put(None)
# Launch workers: one ThreadPoolExecutor per server
executors: List[ThreadPoolExecutor] = []
worker_futures: List[Any] = []
for server_config in self.server_configs:
executor = ThreadPoolExecutor(max_workers=server_config.threads)
executors.append(executor)
for _ in range(server_config.threads):
future = executor.submit(
self._worker, server_config, self, eval_state,
task_queue, results_queue
)
worker_futures.append(future)
# Drain results as they complete
n_correct = 0
session_time = 0.0
completed_count = 0
while completed_count < total_tasks:
task_state = results_queue.get()
eval_state.processed += 1
completed_count += 1
if task_state.correct:
n_correct += 1
elapsed = time.time() - start_time
eval_state.total_time += elapsed
session_time += elapsed
start_time = time.time()
eval_state.print_progress(task_state, total_tasks, n_correct)
if verbose:
print(f"\nCase {eval_state.processed}: {task_state.correct}")
print(f" Expected: {task_state.expected}")
if task_state.response:
print(f" Response: {task_state.response}")
if task_state.answer:
print(f" Answer: {task_state.answer}")
print(f" Status: {task_state.status}")
# Wait for all workers to finish and shut down executors
for future in worker_futures:
future.result()
for executor in executors:
executor.shutdown(wait=True)
print(f"\nSession time: {session_time:.1f}s | Total accumulated time: {eval_state.total_time:.1f}s")
eval_state.print_summary()
eval_state.dump()
def main():
parser = argparse.ArgumentParser(
description="Simplified evaluation tool for llama.cpp"
)
parser.add_argument(
"--server",
type=str,
default="http://localhost:8033",
help="Comma-separated llama-server URLs (default: http://localhost:8033)"
)
parser.add_argument(
"--server-name",
type=str,
default="",
help="Comma-separated display names for servers (default: use URLs)"
)
parser.add_argument(
"--dataset",
type=str,
default="aime",
choices=["aime", "aime2025", "gsm8k", "gpqa"],
help="Dataset type (default: aime)"
)
parser.add_argument(
"--n_cases",
type=int,
default=None,
help="Number of cases to evaluate (default: all)"
)
parser.add_argument(
"--seed",
type=int,
default=1234,
help="Random seed for shuffling (default: 1234)"
)
parser.add_argument(
"--n_predict",
type=int,
default=-1,
help="Max tokens to predict per prompt (default: -1, infinite)"
)
parser.add_argument(
"--temperature",
type=float,
default=None,
help="Sampling temperature (default: not passed)"
)
parser.add_argument(
"--top-k",
type=int,
default=None,
help="Top K sampling (default: not passed)"
)
parser.add_argument(
"--top-p",
type=float,
default=None,
help="Top P sampling (default: not passed)"
)
parser.add_argument(
"--min-p",
type=float,
default=None,
help="Min P sampling (default: not passed)"
)
parser.add_argument(
"--threads",
type=str,
default="32",
help="Comma-separated thread counts per server (default: 32)"
)
parser.add_argument(
"--model",
type=str,
default=None,
help="Model name to append as query parameter (e.g., gpt-oss-20b-hf)"
)
parser.add_argument(
"--verbose",
action="store_true",
help="Show detailed output for each case"
)
parser.add_argument(
"--output",
type=Path,
default=Path("llama-eval-state.json"),
help="Output file for eval state (default: llama-eval-state.json)"
)
parser.add_argument(
"--grader-type",
type=str,
default="llm",
choices=["regex", "cli", "llm"],
help="Grader type: regex, cli, or llm (default: llm)"
)
parser.add_argument(
"--grader-script",
type=str,
default=None,
help="CLI grader script path (required for --grader-type cli)"
)
parser.add_argument(
"--grader-server",
type=str,
default="",
help="Server URL for LLM grader (default: same as main server)"
)
parser.add_argument(
"--grader-model",
type=str,
default="",
help="Model name for LLM grader (default: same as main model)"
)
parser.add_argument(
"--resume",
action="store_true",
help="Resume from existing eval state"
)
args = parser.parse_args()
# Parse server URLs and thread counts
server_urls = [u.strip() for u in args.server.split(",") if u.strip()]
thread_counts = [int(t.strip()) for t in args.threads.split(",") if t.strip()]
if len(server_urls) != len(thread_counts):
print(f"Error: --server ({len(server_urls)} URLs) and --threads ({len(thread_counts)} values) must have the same count")
sys.exit(1)
# Parse server names (optional, defaults to URLs)
if args.server_name:
server_names = [n.strip() for n in args.server_name.split(",") if n.strip()]
if len(server_names) != len(server_urls):
print(f"Error: --server-name ({len(server_names)} names) and --server ({len(server_urls)} URLs) must have the same count")
sys.exit(1)
else:
server_names = server_urls # fallback to URLs
server_configs = [
ServerConfig(url=url, threads=threads, name=name)
for url, threads, name in zip(server_urls, thread_counts, server_names)
]
if args.dataset == "gpqa" and args.grader_type != "llm":
print("Error: GPQA dataset requires --grader-type llm")
parser.print_help()
sys.exit(1)
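# If an eval state file already exists, load it, report current results, and (with --resume) continue from pending tasks; otherwise start a fresh evaluation.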
if args.output.exists():
print(f"Loading existing eval state from {args.output}")
eval_state = EvalState.load(args.output)
# Verify model matches
if eval_state.model_name is not None and args.model != eval_state.model_name:
print(f"Error: Model mismatch. State has '{eval_state.model_name}', but --model is '{args.model}'")
sys.exit(1)
eval_state.print_all_tasks()
eval_state.print_existing_summary()
if eval_state.is_complete():
return
print()
if not args.resume:
print(f"Evaluation incomplete. Run with --resume to continue.")
return
pending_tasks = eval_state.get_pending_tasks()
print(f"Resuming from {len(pending_tasks)} pending tasks")
existing_cases = eval_state.task_states.get("cases", {})
eval_state.tasks = pending_tasks
eval_state.task_states["cases"] = existing_cases
grader_server_url = args.grader_server if args.grader_server else server_configs[0].url
grader_model_name = args.grader_model if args.grader_model else args.model
if args.grader_type == "llm" and not grader_model_name:
print("Error: --grader-type llm requires --grader-model or --model")
sys.exit(1)
grader = Grader(
grader_type=args.grader_type,
grader_script=args.grader_script,
grader_model_name=grader_model_name,
grader_server_url=grader_server_url,
dataset_type=eval_state.dataset_type
)
resume = True
else:
if args.resume:
print("Error: No existing eval state found to resume")
sys.exit(1)
grader_server_url = args.grader_server if args.grader_server else server_configs[0].url
grader_model_name = args.grader_model if args.grader_model else args.model
if args.grader_type == "llm" and not grader_model_name:
print("Error: --grader-type llm requires --grader-model or --model")
sys.exit(1)
grader = Grader(
grader_type=args.grader_type,
grader_script=args.grader_script,
grader_model_name=grader_model_name,
grader_server_url=grader_server_url,
dataset_type=args.dataset
)
if args.grader_type == "llm" and not args.grader_server:
print("Warning: Using same server for LLM grader (no --grader-server specified)")
sampling_config = {}
if args.temperature is not None:
sampling_config["temperature"] = args.temperature
if args.top_k is not None:
sampling_config["top_k"] = args.top_k
if args.top_p is not None:
sampling_config["top_p"] = args.top_p
if args.min_p is not None:
sampling_config["min_p"] = args.min_p
eval_state = EvalState(
dataset_type=args.dataset,
sampling_config=sampling_config,
output_file=args.output,
model_name=args.model
)
eval_state.load_dataset(seed=args.seed)
eval_state.setup_tasks(n_cases=args.n_cases, seed=args.seed)
eval_state.dump()
resume = False
eval_state.print_all_tasks()
processor = Processor(
server_configs=server_configs,
grader=grader,
model_name=args.model,
n_predict=args.n_predict
)
processor.evaluate(eval_state, verbose=args.verbose, resume=resume)
print(f"\nEval state dumped to {args.output}")
if __name__ == "__main__":
main()