diff --git a/examples/llama-eval/llama-eval.py b/examples/llama-eval/llama-eval.py
index 646160d1a9..f27f03f479 100755
--- a/examples/llama-eval/llama-eval.py
+++ b/examples/llama-eval/llama-eval.py
@@ -16,6 +16,17 @@ from typing import Dict, List, Optional, Any, Tuple
 import requests
 from tqdm import tqdm
 import random
+from math import sqrt
+
+def wilson_interval(correct: int, total: int, z: float = 1.96) -> Tuple[float, float]:
+    """Return the Wilson score interval (lower, upper) for correct/total, clamped to [0, 1]."""
+    if total == 0:
+        return (0.0, 1.0)  # no observations: the proportion is unconstrained
+    p = correct / total
+    z2 = z * z / total  # z^2 / n, the term shared by the center and the margin
+    center = (p + z2 / 2) / (1 + z2)
+    margin = z * sqrt((p * (1 - p) + z2 / 4) / total) / (1 + z2)
+    return (max(0.0, center - margin), min(1.0, center + margin))  # rounding can overshoot at p = 0 or 1
 
 cache_dir = Path.home() / ".cache" / "huggingface" / "datasets"
 cache_dir.mkdir(parents=True, exist_ok=True)
@@ -227,8 +238,9 @@ class EvalState:
             print(f"Results: 0/0 correct (0.0%)")
             print(f"{'='*60}")
         else:
+            ci_lower, ci_upper = self.accuracy_ci()
             print(f"\n{'='*60}")
-            print(f"Results: {self.correct}/{self.total} correct ({self.correct/self.total*100:.1f}%)")
+            print(f"Results: {self.correct}/{self.total} correct ({self.correct/self.total*100:.1f}%) [{ci_lower*100:.1f}%, {ci_upper*100:.1f}%]")
             print(f"{'='*60}")
 
     def dump(self):
@@ -253,6 +265,7 @@ class EvalState:
                     "reasoning_content": None
                 }
 
+        ci_lower, ci_upper = self.accuracy_ci()
         data = {
             "id": self.dataset_type,
             "tasks": [tid for _, tid in tasks_to_save],
@@ -260,6 +273,8 @@ class EvalState:
                 "total": self.total,
                 "correct": self.correct,
                 "total_time": self.total_time,
+                "ci_lower": ci_lower,
+                "ci_upper": ci_upper,
                 "cases": all_cases,
             },
             "sampling_config": self.sampling_config
@@ -278,6 +293,7 @@ class EvalState:
         incorrect_count = len(completed) - correct_count
         pending_count = len(tasks_to_save) - len(completed)
         accuracy = correct_count / len(completed) * 100 if completed else 0.0
+        ci_lower, ci_upper = wilson_interval(correct_count, len(completed)) if completed else (0.0, 1.0)
 
         sampling_parts = []
         for k, v in self.sampling_config.items():
@@ -378,7 +394,7 @@ class EvalState: