From 81a65cf0350488ff6791afc7aae2e6fc861bdee0 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 10 May 2026 18:46:36 +0300 Subject: [PATCH] eval : add Wilson score confidence interval to results Compute 95% CI on-the-fly from completed cases. Displayed in terminal output, HTML report, and JSON state. --- examples/llama-eval/llama-eval.py | 31 ++++++++++++++++++++++++++++--- 1 file changed, 28 insertions(+), 3 deletions(-) diff --git a/examples/llama-eval/llama-eval.py b/examples/llama-eval/llama-eval.py index 646160d1a9..f27f03f479 100755 --- a/examples/llama-eval/llama-eval.py +++ b/examples/llama-eval/llama-eval.py @@ -16,6 +16,17 @@ from typing import Dict, List, Optional, Any, Tuple import requests from tqdm import tqdm import random +from math import sqrt + +def wilson_interval(correct: int, total: int, z: float = 1.96) -> Tuple[float, float]: + """Wilson score confidence interval for a proportion.""" + if total == 0: + return (0.0, 1.0) + p = correct / total + z2 = z * z / total + center = (p + z2 / 2) / (1 + z2) + margin = z * sqrt((p * (1 - p) + z2 / 4) / total) / (1 + z2) + return (center - margin, center + margin) cache_dir = Path.home() / ".cache" / "huggingface" / "datasets" cache_dir.mkdir(parents=True, exist_ok=True) @@ -227,8 +238,9 @@ class EvalState: print(f"Results: 0/0 correct (0.0%)") print(f"{'='*60}") else: + ci_lower, ci_upper = self.accuracy_ci() print(f"\n{'='*60}") - print(f"Results: {self.correct}/{self.total} correct ({self.correct/self.total*100:.1f}%)") + print(f"Results: {self.correct}/{self.total} correct ({self.correct/self.total*100:.1f}%) [{ci_lower*100:.1f}%, {ci_upper*100:.1f}%]") print(f"{'='*60}") def dump(self): @@ -253,6 +265,7 @@ class EvalState: "reasoning_content": None } + ci_lower, ci_upper = self.accuracy_ci() data = { "id": self.dataset_type, "tasks": [tid for _, tid in tasks_to_save], @@ -260,6 +273,8 @@ class EvalState: "total": self.total, "correct": self.correct, "total_time": self.total_time, + "ci_lower": ci_lower, + "ci_upper": ci_upper, "cases": all_cases, }, "sampling_config": self.sampling_config @@ -278,6 +293,7 @@ class EvalState: incorrect_count = len(completed) - correct_count pending_count = len(tasks_to_save) - len(completed) accuracy = correct_count / len(completed) * 100 if completed else 0.0 + ci_lower, ci_upper = wilson_interval(correct_count, len(completed)) if completed else (0.0, 1.0) sampling_parts = [] for k, v in self.sampling_config.items(): @@ -378,7 +394,7 @@ class EvalState: Correct{correct_count} Incorrect{incorrect_count} Pending{pending_count} - Accuracy{accuracy:.1f}% + Accuracy{accuracy:.1f}% [{ci_lower*100:.1f}%, {ci_upper*100:.1f}%] Total Time{self.total_time:.1f}s Sampling{sampling_str} @@ -510,10 +526,19 @@ class EvalState: print(f"Results: 0/0 correct (0.0%)") print(f"{'='*60}") else: + ci_lower, ci_upper = self.accuracy_ci() print(f"{'='*60}") - print(f"Results: {correct}/{total} correct ({correct/total*100:.1f}%)") + print(f"Results: {correct}/{total} correct ({correct/total*100:.1f}%) [{ci_lower*100:.1f}%, {ci_upper*100:.1f}%]") print(f"{'='*60}") + def accuracy_ci(self) -> Tuple[float, float]: + """Compute Wilson score confidence interval from completed cases.""" + cases = self.task_states.get("cases", {}) + completed = {tid: c for tid, c in cases.items() if c.get("status") == "ok"} + correct = sum(1 for c in completed.values() if c.get("correct", False)) + total = len(completed) + return wilson_interval(correct, total) + def normalize_number(s: str) -> Optional[int]: match = re.match(r"\d+", s) # match digits from the start if not match: