mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2026-05-11 19:44:06 +00:00
eval : add Wilson score confidence interval to results
Compute 95% CI on-the-fly from completed cases. Displayed in terminal output, HTML report, and JSON state.
This commit is contained in:
@@ -16,6 +16,17 @@ from typing import Dict, List, Optional, Any, Tuple
|
||||
import requests
|
||||
from tqdm import tqdm
|
||||
import random
|
||||
from math import sqrt
|
||||
|
||||
def wilson_interval(correct: int, total: int, z: float = 1.96) -> Tuple[float, float]:
|
||||
"""Wilson score confidence interval for a proportion."""
|
||||
if total == 0:
|
||||
return (0.0, 1.0)
|
||||
p = correct / total
|
||||
z2 = z * z / total
|
||||
center = (p + z2 / 2) / (1 + z2)
|
||||
margin = z * sqrt((p * (1 - p) + z2 / 4) / total) / (1 + z2)
|
||||
return (center - margin, center + margin)
|
||||
|
||||
cache_dir = Path.home() / ".cache" / "huggingface" / "datasets"
|
||||
cache_dir.mkdir(parents=True, exist_ok=True)
|
||||
@@ -227,8 +238,9 @@ class EvalState:
|
||||
print(f"Results: 0/0 correct (0.0%)")
|
||||
print(f"{'='*60}")
|
||||
else:
|
||||
ci_lower, ci_upper = self.accuracy_ci()
|
||||
print(f"\n{'='*60}")
|
||||
print(f"Results: {self.correct}/{self.total} correct ({self.correct/self.total*100:.1f}%)")
|
||||
print(f"Results: {self.correct}/{self.total} correct ({self.correct/self.total*100:.1f}%) [{ci_lower*100:.1f}%, {ci_upper*100:.1f}%]")
|
||||
print(f"{'='*60}")
|
||||
|
||||
def dump(self):
|
||||
@@ -253,6 +265,7 @@ class EvalState:
|
||||
"reasoning_content": None
|
||||
}
|
||||
|
||||
ci_lower, ci_upper = self.accuracy_ci()
|
||||
data = {
|
||||
"id": self.dataset_type,
|
||||
"tasks": [tid for _, tid in tasks_to_save],
|
||||
@@ -260,6 +273,8 @@ class EvalState:
|
||||
"total": self.total,
|
||||
"correct": self.correct,
|
||||
"total_time": self.total_time,
|
||||
"ci_lower": ci_lower,
|
||||
"ci_upper": ci_upper,
|
||||
"cases": all_cases,
|
||||
},
|
||||
"sampling_config": self.sampling_config
|
||||
@@ -278,6 +293,7 @@ class EvalState:
|
||||
incorrect_count = len(completed) - correct_count
|
||||
pending_count = len(tasks_to_save) - len(completed)
|
||||
accuracy = correct_count / len(completed) * 100 if completed else 0.0
|
||||
ci_lower, ci_upper = wilson_interval(correct_count, len(completed)) if completed else (0.0, 1.0)
|
||||
|
||||
sampling_parts = []
|
||||
for k, v in self.sampling_config.items():
|
||||
@@ -378,7 +394,7 @@ class EvalState:
|
||||
<tr><td>Correct</td><td class="correct">{correct_count}</td></tr>
|
||||
<tr><td>Incorrect</td><td class="incorrect">{incorrect_count}</td></tr>
|
||||
<tr><td>Pending</td><td class="pending">{pending_count}</td></tr>
|
||||
<tr><td>Accuracy</td><td>{accuracy:.1f}%</td></tr>
|
||||
<tr><td>Accuracy</td><td>{accuracy:.1f}% [{ci_lower*100:.1f}%, {ci_upper*100:.1f}%]</td></tr>
|
||||
<tr><td>Total Time</td><td>{self.total_time:.1f}s</td></tr>
|
||||
<tr><td>Sampling</td><td>{sampling_str}</td></tr>
|
||||
</table>
|
||||
@@ -510,10 +526,19 @@ class EvalState:
|
||||
print(f"Results: 0/0 correct (0.0%)")
|
||||
print(f"{'='*60}")
|
||||
else:
|
||||
ci_lower, ci_upper = self.accuracy_ci()
|
||||
print(f"{'='*60}")
|
||||
print(f"Results: {correct}/{total} correct ({correct/total*100:.1f}%)")
|
||||
print(f"Results: {correct}/{total} correct ({correct/total*100:.1f}%) [{ci_lower*100:.1f}%, {ci_upper*100:.1f}%]")
|
||||
print(f"{'='*60}")
|
||||
|
||||
def accuracy_ci(self) -> Tuple[float, float]:
|
||||
"""Compute Wilson score confidence interval from completed cases."""
|
||||
cases = self.task_states.get("cases", {})
|
||||
completed = {tid: c for tid, c in cases.items() if c.get("status") == "ok"}
|
||||
correct = sum(1 for c in completed.values() if c.get("correct", False))
|
||||
total = len(completed)
|
||||
return wilson_interval(correct, total)
|
||||
|
||||
def normalize_number(s: str) -> Optional[int]:
|
||||
match = re.match(r"\d+", s) # match digits from the start
|
||||
if not match:
|
||||
|
||||
Reference in New Issue
Block a user