This commit is contained in:
Georgi Gerganov
2026-02-16 23:02:45 +02:00
parent 7f049860b4
commit c0c3e428dd

View File

@@ -85,7 +85,7 @@ class BaseDataset(ABC):
pass
@abstractmethod
def get_question_str(self, question: Dict) -> str:
def get_question_text(self, question: Dict) -> str:
pass
@abstractmethod
@@ -102,11 +102,12 @@ class BaseDataset(ABC):
@dataclass
class TaskState:
case_id: str
task_id: str
prompt: str
gold: str
result: Optional[str] = None
extracted: Optional[str] = None
expected: str
question_text: str = ""
response: Optional[str] = None
answer: Optional[str] = None
grader_log: Dict[str, Any] = field(default_factory=dict)
correct: bool = False
status: str = "pending"
@@ -171,18 +172,18 @@ class EvalState:
if self.dataset is None:
raise ValueError("Dataset not loaded.")
question = self.dataset.get_question(index)
question_str = self.dataset.get_question_str(question)
question_text = self.dataset.get_question_text(question)
prompt = self.dataset.get_prompt(question)
gold = self.dataset.get_answer(question)
return question_str, prompt, gold
expected = self.dataset.get_answer(question)
return question_text, prompt, expected
def add_result(
self,
task_id: str,
prompt: str,
gold: str,
result: Optional[str],
extracted: Optional[str],
expected: str,
response: Optional[str],
answer: Optional[str],
grader_log: Dict[str, Any],
correct: bool,
status: str,
@@ -193,11 +194,11 @@ class EvalState:
self.task_states["cases"] = {}
self.task_states["cases"][task_id] = {
"case_id": task_id,
"task_id": task_id,
"prompt": prompt,
"gold": gold,
"result": result,
"extracted": extracted,
"expected": expected,
"response": response,
"answer": answer,
"grader_log": grader_log,
"correct": correct,
"status": status,
@@ -205,22 +206,19 @@ class EvalState:
"reasoning_content": reasoning_content
}
if correct:
self.correct += 1
else:
self.correct = sum(1 for c in self.task_states.get("cases", {}).values() if c.get("correct", False))
self.correct = sum(1 for c in self.task_states.get("cases", {}).values() if c.get("correct", False))
def print_progress(self, task_state: TaskState, total_tasks: int, correct_count: int = 0):
extracted_display = task_state.extracted if task_state.extracted else "N/A"
answer_display = task_state.answer if task_state.answer else "N/A"
tokens_display = str(task_state.tokens) if task_state.tokens is not None else "N/A"
success_ratio = correct_count / self.processed if self.processed > 0 else 0.0
first_line = task_state.prompt.split('\n')[0]
truncated_prompt = first_line[:43]
first_line = task_state.question_text.split('\n')[0]
truncated_question = first_line[:43]
if len(first_line) > 43:
truncated_prompt += "..."
truncated_question += "..."
else:
truncated_prompt = truncated_prompt.ljust(43) + "..."
print(f"{self.processed:3}/{total_tasks:3} {task_state.case_id:<20} {self.dataset_type.upper()} {truncated_prompt:<40} {task_state.gold:<10} {extracted_display:<10} {tokens_display:<6} {'' if task_state.correct else ''} [{correct_count:3}/{self.processed:3}, {success_ratio:.3f}]")
truncated_question = truncated_question.ljust(43) + "..."
print(f"{self.processed:3}/{total_tasks:3} {task_state.task_id:<20} {self.dataset_type.upper()} {truncated_question:<40} {task_state.expected:<10} {answer_display:<10} {tokens_display:<6} {'' if task_state.correct else ''} [{correct_count:3}/{self.processed:3}, {success_ratio:.3f}]")
def print_summary(self):
if self.total == 0:
@@ -236,16 +234,17 @@ class EvalState:
tasks_to_save = self.all_tasks if self.all_tasks else self.tasks
all_cases = {}
for i, task_id in tasks_to_save:
question, prompt, gold = self.get_case(i)
question_text, prompt, expected = self.get_case(i)
if task_id in self.task_states.get("cases", {}):
all_cases[task_id] = self.task_states["cases"][task_id]
else:
all_cases[task_id] = {
"case_id": task_id,
"task_id": task_id,
"prompt": prompt,
"gold": gold,
"result": None,
"extracted": None,
"expected": expected,
"question_text": question_text,
"response": None,
"answer": None,
"grader_log": {},
"correct": False,
"status": "pending",
@@ -288,10 +287,10 @@ class EvalState:
for i, task_id in tasks_to_save:
case = cases.get(task_id, {})
status = case.get("status", "pending")
gold = case.get("gold", "")
extracted = case.get("extracted", "") if status == "ok" else ""
expected = case.get("expected", "")
answer = case.get("answer", "") if status == "ok" else ""
is_correct = case.get("correct", False) if status == "ok" else False
result = case.get("result", "") or ""
response = case.get("response", "") or ""
prompt = case.get("prompt", "") or ""
grader_log = case.get("grader_log", {})
@@ -309,7 +308,7 @@ class EvalState:
tokens_str = str(tokens) if tokens is not None else ""
reasoning_content = case.get("reasoning_content", "") or ""
result_escaped = self._escape_html(result)
response_escaped = self._escape_html(response)
prompt_escaped = self._escape_html(prompt)
reasoning_escaped = self._escape_html(reasoning_content)
grader_log_str = self._escape_html(json.dumps(grader_log, indent=2))
@@ -317,8 +316,8 @@ class EvalState:
rows.append(f"""<tr class="task-row" onclick="toggleDetails('{task_id}')">
<td>{task_id}</td>
<td class="{status_class}">{status_text}</td>
<td>{self._escape_html(gold)}</td>
<td>{self._escape_html(extracted)}</td>
<td>{self._escape_html(expected)}</td>
<td>{self._escape_html(answer)}</td>
<td>{tokens_str}</td>
</tr>
<tr id="details-{task_id}" class="details-row">
@@ -328,8 +327,8 @@ class EvalState:
<pre>{prompt_escaped}</pre>
<h4 onclick="toggleReasoning('{task_id}')" style="cursor:pointer">Reasoning &#9654;</h4>
<pre id="reasoning-{task_id}" style="display:none">{reasoning_escaped}</pre>
<h4>Result</h4>
<pre>{result_escaped}</pre>
<h4>Response</h4>
<pre>{response_escaped}</pre>
<h4>Grader Log</h4>
<pre>{grader_log_str}</pre>
</div>
@@ -478,12 +477,12 @@ class EvalState:
tasks_to_show = self.all_tasks if self.all_tasks else self.tasks
print()
print("Tasks:")
print(" Task ID Dataset Prompt (first 40 chars) Expected Extracted Tokens Status")
print(" Task ID Dataset Prompt (first 40 chars) Expected Answer Tokens Status")
for i, task_id in tasks_to_show:
question, prompt, gold = self.get_case(i)
question, prompt, expected = self.get_case(i)
case = cases.get(task_id, {})
status = case.get("status", "pending")
extracted = case.get("extracted", "N/A") if status == "ok" else "N/A"
answer = case.get("answer", "N/A") if status == "ok" else "N/A"
tokens = case.get("tokens")
tokens_str = str(tokens) if tokens is not None else "N/A"
is_correct = case.get("correct", False) if status == "ok" else False
@@ -494,7 +493,7 @@ class EvalState:
question_trunc += "..."
else:
question_trunc = question_trunc.ljust(43) + "..."
print(f" {task_id:<20} {self.dataset_type.upper()} {question_trunc:<40} {gold:<10} {extracted:<10} {tokens_str:<6} {symbol}{status}")
print(f" {task_id:<20} {self.dataset_type.upper()} {question_trunc:<40} {expected:<10} {answer:<10} {tokens_str:<6} {symbol}{status}")
print()
def print_existing_summary(self):
@@ -546,7 +545,7 @@ class AimeDataset(BaseDataset):
"""Get question by index"""
return self.questions[index]
def get_question_str(self, question: Dict) -> str:
def get_question_text(self, question: Dict) -> str:
"""Get question string"""
return question["problem"] if "problem" in question else question["question"]
@@ -560,7 +559,7 @@ class AimeDataset(BaseDataset):
def get_prompt(self, question: Dict) -> str:
"""Get formatted prompt for the question"""
return TEMPLATE_REGISTRY[question["dataset_type"]].format(
question=self.get_question_str(question),
question=self.get_question_text(question),
)
class Aime2025Dataset(BaseDataset):
@@ -608,7 +607,7 @@ class Aime2025Dataset(BaseDataset):
"""Get question by index"""
return self.questions[index]
def get_question_str(self, question: Dict) -> str:
def get_question_text(self, question: Dict) -> str:
"""Get question string"""
return question["question"]
@@ -622,7 +621,7 @@ class Aime2025Dataset(BaseDataset):
def get_prompt(self, question: Dict) -> str:
"""Get formatted prompt for the question"""
return TEMPLATE_REGISTRY["aime2025"].format(
question=self.get_question_str(question),
question=self.get_question_text(question),
)
class Gsm8kDataset(BaseDataset):
@@ -665,7 +664,7 @@ class Gsm8kDataset(BaseDataset):
"""Get question by index"""
return self.questions[index]
def get_question_str(self, question: Dict) -> str:
def get_question_text(self, question: Dict) -> str:
"""Get question string"""
return question["problem"] if "problem" in question else question["question"]
@@ -682,7 +681,7 @@ class Gsm8kDataset(BaseDataset):
def get_prompt(self, question: Dict) -> str:
"""Get formatted prompt for the question"""
return TEMPLATE_REGISTRY[question["dataset_type"]].format(
question=self.get_question_str(question),
question=self.get_question_text(question),
)
class GpqaDataset(BaseDataset):
@@ -737,7 +736,7 @@ class GpqaDataset(BaseDataset):
"""Get question by index"""
return self.questions[index]
def get_question_str(self, question: Dict) -> str:
def get_question_text(self, question: Dict) -> str:
"""Get question string"""
return question["Question"]
@@ -748,7 +747,7 @@ class GpqaDataset(BaseDataset):
def get_prompt(self, question: Dict) -> str:
"""Get formatted prompt for the question"""
return TEMPLATE_REGISTRY["gpqa"].format(
Question=self.get_question_str(question),
Question=self.get_question_text(question),
A=question["shuffled_answers"][0],
B=question["shuffled_answers"][1],
C=question["shuffled_answers"][2],
@@ -799,18 +798,18 @@ class Grader:
for match in reversed(matches):
if isinstance(match, tuple):
match = match[0] if match[0] else match[1]
extracted = match.strip()
if extracted:
return extracted
answer = match.strip()
if answer:
return answer
return None
def _grade_regex(self, gold: str, pred: str) -> Tuple[bool, Optional[str]]:
"""Grade using regex pattern matching"""
extracted = self._extract_answer_regex(pred)
if extracted is None:
answer = self._extract_answer_regex(pred)
if answer is None:
return False, None
is_correct = extracted.strip() == gold.strip()
return is_correct, extracted
is_correct = answer.strip() == gold.strip()
return is_correct, answer
def _grade_cli(self, gold: str, pred: str) -> Tuple[bool, Optional[str]]:
"""Grade using external CLI script"""
@@ -829,8 +828,8 @@ class Grader:
timeout=30
)
is_correct = result.returncode == 0
extracted = pred if is_correct else None
return is_correct, extracted
answer = pred if is_correct else None
return is_correct, answer
except subprocess.TimeoutExpired:
return False, None
except Exception as e:
@@ -872,9 +871,9 @@ Please provide only the extracted answer, nothing else. If there is no clear ans
try:
response = requests.post(url, headers=headers, json=data)
response.raise_for_status()
extracted = response.json()["choices"][0]["message"]["content"].strip()
is_correct = extracted.strip().lower() == gold.strip().lower()
return is_correct, extracted
answer = response.json()["choices"][0]["message"]["content"].strip()
is_correct = answer.strip().lower() == gold.strip().lower()
return is_correct, answer
except Exception as e:
return False, None
@@ -934,30 +933,31 @@ class Processor:
return result, tokens, finish_reason
def _process_single_case(self, eval_state: EvalState, i: int, task_id: str) -> TaskState:
question, prompt, gold = eval_state.get_case(i)
question_text, prompt, expected = eval_state.get_case(i)
task_state = TaskState(
case_id=task_id,
task_id=task_id,
prompt=prompt,
gold=gold
expected=expected,
question_text=question_text
)
try:
response, tokens, finish_reason = self._make_request(eval_state, prompt)
result = response["choices"][0]["message"]["content"]
reasoning_content = response["choices"][0].get("message", {}).get("reasoning_content")
task_state.result = result
task_state.response = result
task_state.tokens = tokens
task_state.reasoning_content = reasoning_content
if finish_reason != "stop":
task_state.status = f"error: finish_reason={finish_reason}"
eval_state.add_result(task_id, prompt, gold, result, None, {"finish_reason": finish_reason}, False, task_state.status, tokens, reasoning_content)
eval_state.add_result(task_id, prompt, expected, result, None, {"finish_reason": finish_reason}, False, task_state.status, tokens, reasoning_content)
eval_state.dump()
return task_state
result_truncated = self.grader._truncate_response(result, max_lines=10)
is_correct, extracted = self.grader.grade(gold, result_truncated, prompt)
is_correct, answer = self.grader.grade(expected, result_truncated, prompt)
grader_log = {
"pred": result_truncated,
@@ -967,11 +967,11 @@ class Processor:
grader_log["pattern"] = self.grader.pattern
task_state.correct = is_correct
task_state.extracted = extracted
task_state.answer = answer
task_state.grader_log = grader_log
task_state.status = "ok"
eval_state.add_result(task_id, prompt, gold, result, extracted, grader_log, is_correct, "ok", tokens, reasoning_content)
eval_state.add_result(task_id, prompt, expected, result, answer, grader_log, is_correct, "ok", tokens, reasoning_content)
eval_state.dump()
@@ -1009,11 +1009,11 @@ class Processor:
if verbose:
print(f"\nCase {eval_state.processed}: {task_state.correct}")
print(f" Gold: {task_state.gold}")
if task_state.result:
print(f" Result: {task_state.result}")
if task_state.extracted:
print(f" Extracted: {task_state.extracted}")
print(f" Expected: {task_state.expected}")
if task_state.response:
print(f" Response: {task_state.response}")
if task_state.answer:
print(f" Answer: {task_state.answer}")
print(f" Status: {task_state.status}")
eval_state.print_summary()