mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2026-03-17 16:44:07 +00:00
refactor
This commit is contained in:
@@ -85,7 +85,7 @@ class BaseDataset(ABC):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_question_str(self, question: Dict) -> str:
|
||||
def get_question_text(self, question: Dict) -> str:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
@@ -102,11 +102,12 @@ class BaseDataset(ABC):
|
||||
|
||||
@dataclass
|
||||
class TaskState:
|
||||
case_id: str
|
||||
task_id: str
|
||||
prompt: str
|
||||
gold: str
|
||||
result: Optional[str] = None
|
||||
extracted: Optional[str] = None
|
||||
expected: str
|
||||
question_text: str = ""
|
||||
response: Optional[str] = None
|
||||
answer: Optional[str] = None
|
||||
grader_log: Dict[str, Any] = field(default_factory=dict)
|
||||
correct: bool = False
|
||||
status: str = "pending"
|
||||
@@ -171,18 +172,18 @@ class EvalState:
|
||||
if self.dataset is None:
|
||||
raise ValueError("Dataset not loaded.")
|
||||
question = self.dataset.get_question(index)
|
||||
question_str = self.dataset.get_question_str(question)
|
||||
question_text = self.dataset.get_question_text(question)
|
||||
prompt = self.dataset.get_prompt(question)
|
||||
gold = self.dataset.get_answer(question)
|
||||
return question_str, prompt, gold
|
||||
expected = self.dataset.get_answer(question)
|
||||
return question_text, prompt, expected
|
||||
|
||||
def add_result(
|
||||
self,
|
||||
task_id: str,
|
||||
prompt: str,
|
||||
gold: str,
|
||||
result: Optional[str],
|
||||
extracted: Optional[str],
|
||||
expected: str,
|
||||
response: Optional[str],
|
||||
answer: Optional[str],
|
||||
grader_log: Dict[str, Any],
|
||||
correct: bool,
|
||||
status: str,
|
||||
@@ -193,11 +194,11 @@ class EvalState:
|
||||
self.task_states["cases"] = {}
|
||||
|
||||
self.task_states["cases"][task_id] = {
|
||||
"case_id": task_id,
|
||||
"task_id": task_id,
|
||||
"prompt": prompt,
|
||||
"gold": gold,
|
||||
"result": result,
|
||||
"extracted": extracted,
|
||||
"expected": expected,
|
||||
"response": response,
|
||||
"answer": answer,
|
||||
"grader_log": grader_log,
|
||||
"correct": correct,
|
||||
"status": status,
|
||||
@@ -205,22 +206,19 @@ class EvalState:
|
||||
"reasoning_content": reasoning_content
|
||||
}
|
||||
|
||||
if correct:
|
||||
self.correct += 1
|
||||
else:
|
||||
self.correct = sum(1 for c in self.task_states.get("cases", {}).values() if c.get("correct", False))
|
||||
self.correct = sum(1 for c in self.task_states.get("cases", {}).values() if c.get("correct", False))
|
||||
|
||||
def print_progress(self, task_state: TaskState, total_tasks: int, correct_count: int = 0):
|
||||
extracted_display = task_state.extracted if task_state.extracted else "N/A"
|
||||
answer_display = task_state.answer if task_state.answer else "N/A"
|
||||
tokens_display = str(task_state.tokens) if task_state.tokens is not None else "N/A"
|
||||
success_ratio = correct_count / self.processed if self.processed > 0 else 0.0
|
||||
first_line = task_state.prompt.split('\n')[0]
|
||||
truncated_prompt = first_line[:43]
|
||||
first_line = task_state.question_text.split('\n')[0]
|
||||
truncated_question = first_line[:43]
|
||||
if len(first_line) > 43:
|
||||
truncated_prompt += "..."
|
||||
truncated_question += "..."
|
||||
else:
|
||||
truncated_prompt = truncated_prompt.ljust(43) + "..."
|
||||
print(f"{self.processed:3}/{total_tasks:3} {task_state.case_id:<20} {self.dataset_type.upper()} {truncated_prompt:<40} {task_state.gold:<10} {extracted_display:<10} {tokens_display:<6} {'✓' if task_state.correct else '✗'} [{correct_count:3}/{self.processed:3}, {success_ratio:.3f}]")
|
||||
truncated_question = truncated_question.ljust(43) + "..."
|
||||
print(f"{self.processed:3}/{total_tasks:3} {task_state.task_id:<20} {self.dataset_type.upper()} {truncated_question:<40} {task_state.expected:<10} {answer_display:<10} {tokens_display:<6} {'✓' if task_state.correct else '✗'} [{correct_count:3}/{self.processed:3}, {success_ratio:.3f}]")
|
||||
|
||||
def print_summary(self):
|
||||
if self.total == 0:
|
||||
@@ -236,16 +234,17 @@ class EvalState:
|
||||
tasks_to_save = self.all_tasks if self.all_tasks else self.tasks
|
||||
all_cases = {}
|
||||
for i, task_id in tasks_to_save:
|
||||
question, prompt, gold = self.get_case(i)
|
||||
question_text, prompt, expected = self.get_case(i)
|
||||
if task_id in self.task_states.get("cases", {}):
|
||||
all_cases[task_id] = self.task_states["cases"][task_id]
|
||||
else:
|
||||
all_cases[task_id] = {
|
||||
"case_id": task_id,
|
||||
"task_id": task_id,
|
||||
"prompt": prompt,
|
||||
"gold": gold,
|
||||
"result": None,
|
||||
"extracted": None,
|
||||
"expected": expected,
|
||||
"question_text": question_text,
|
||||
"response": None,
|
||||
"answer": None,
|
||||
"grader_log": {},
|
||||
"correct": False,
|
||||
"status": "pending",
|
||||
@@ -288,10 +287,10 @@ class EvalState:
|
||||
for i, task_id in tasks_to_save:
|
||||
case = cases.get(task_id, {})
|
||||
status = case.get("status", "pending")
|
||||
gold = case.get("gold", "")
|
||||
extracted = case.get("extracted", "") if status == "ok" else ""
|
||||
expected = case.get("expected", "")
|
||||
answer = case.get("answer", "") if status == "ok" else ""
|
||||
is_correct = case.get("correct", False) if status == "ok" else False
|
||||
result = case.get("result", "") or ""
|
||||
response = case.get("response", "") or ""
|
||||
prompt = case.get("prompt", "") or ""
|
||||
grader_log = case.get("grader_log", {})
|
||||
|
||||
@@ -309,7 +308,7 @@ class EvalState:
|
||||
tokens_str = str(tokens) if tokens is not None else ""
|
||||
reasoning_content = case.get("reasoning_content", "") or ""
|
||||
|
||||
result_escaped = self._escape_html(result)
|
||||
response_escaped = self._escape_html(response)
|
||||
prompt_escaped = self._escape_html(prompt)
|
||||
reasoning_escaped = self._escape_html(reasoning_content)
|
||||
grader_log_str = self._escape_html(json.dumps(grader_log, indent=2))
|
||||
@@ -317,8 +316,8 @@ class EvalState:
|
||||
rows.append(f"""<tr class="task-row" onclick="toggleDetails('{task_id}')">
|
||||
<td>{task_id}</td>
|
||||
<td class="{status_class}">{status_text}</td>
|
||||
<td>{self._escape_html(gold)}</td>
|
||||
<td>{self._escape_html(extracted)}</td>
|
||||
<td>{self._escape_html(expected)}</td>
|
||||
<td>{self._escape_html(answer)}</td>
|
||||
<td>{tokens_str}</td>
|
||||
</tr>
|
||||
<tr id="details-{task_id}" class="details-row">
|
||||
@@ -328,8 +327,8 @@ class EvalState:
|
||||
<pre>{prompt_escaped}</pre>
|
||||
<h4 onclick="toggleReasoning('{task_id}')" style="cursor:pointer">Reasoning ▶</h4>
|
||||
<pre id="reasoning-{task_id}" style="display:none">{reasoning_escaped}</pre>
|
||||
<h4>Result</h4>
|
||||
<pre>{result_escaped}</pre>
|
||||
<h4>Response</h4>
|
||||
<pre>{response_escaped}</pre>
|
||||
<h4>Grader Log</h4>
|
||||
<pre>{grader_log_str}</pre>
|
||||
</div>
|
||||
@@ -478,12 +477,12 @@ class EvalState:
|
||||
tasks_to_show = self.all_tasks if self.all_tasks else self.tasks
|
||||
print()
|
||||
print("Tasks:")
|
||||
print(" Task ID Dataset Prompt (first 40 chars) Expected Extracted Tokens Status")
|
||||
print(" Task ID Dataset Prompt (first 40 chars) Expected Answer Tokens Status")
|
||||
for i, task_id in tasks_to_show:
|
||||
question, prompt, gold = self.get_case(i)
|
||||
question, prompt, expected = self.get_case(i)
|
||||
case = cases.get(task_id, {})
|
||||
status = case.get("status", "pending")
|
||||
extracted = case.get("extracted", "N/A") if status == "ok" else "N/A"
|
||||
answer = case.get("answer", "N/A") if status == "ok" else "N/A"
|
||||
tokens = case.get("tokens")
|
||||
tokens_str = str(tokens) if tokens is not None else "N/A"
|
||||
is_correct = case.get("correct", False) if status == "ok" else False
|
||||
@@ -494,7 +493,7 @@ class EvalState:
|
||||
question_trunc += "..."
|
||||
else:
|
||||
question_trunc = question_trunc.ljust(43) + "..."
|
||||
print(f" {task_id:<20} {self.dataset_type.upper()} {question_trunc:<40} {gold:<10} {extracted:<10} {tokens_str:<6} {symbol}{status}")
|
||||
print(f" {task_id:<20} {self.dataset_type.upper()} {question_trunc:<40} {expected:<10} {answer:<10} {tokens_str:<6} {symbol}{status}")
|
||||
print()
|
||||
|
||||
def print_existing_summary(self):
|
||||
@@ -546,7 +545,7 @@ class AimeDataset(BaseDataset):
|
||||
"""Get question by index"""
|
||||
return self.questions[index]
|
||||
|
||||
def get_question_str(self, question: Dict) -> str:
|
||||
def get_question_text(self, question: Dict) -> str:
|
||||
"""Get question string"""
|
||||
return question["problem"] if "problem" in question else question["question"]
|
||||
|
||||
@@ -560,7 +559,7 @@ class AimeDataset(BaseDataset):
|
||||
def get_prompt(self, question: Dict) -> str:
|
||||
"""Get formatted prompt for the question"""
|
||||
return TEMPLATE_REGISTRY[question["dataset_type"]].format(
|
||||
question=self.get_question_str(question),
|
||||
question=self.get_question_text(question),
|
||||
)
|
||||
|
||||
class Aime2025Dataset(BaseDataset):
|
||||
@@ -608,7 +607,7 @@ class Aime2025Dataset(BaseDataset):
|
||||
"""Get question by index"""
|
||||
return self.questions[index]
|
||||
|
||||
def get_question_str(self, question: Dict) -> str:
|
||||
def get_question_text(self, question: Dict) -> str:
|
||||
"""Get question string"""
|
||||
return question["question"]
|
||||
|
||||
@@ -622,7 +621,7 @@ class Aime2025Dataset(BaseDataset):
|
||||
def get_prompt(self, question: Dict) -> str:
|
||||
"""Get formatted prompt for the question"""
|
||||
return TEMPLATE_REGISTRY["aime2025"].format(
|
||||
question=self.get_question_str(question),
|
||||
question=self.get_question_text(question),
|
||||
)
|
||||
|
||||
class Gsm8kDataset(BaseDataset):
|
||||
@@ -665,7 +664,7 @@ class Gsm8kDataset(BaseDataset):
|
||||
"""Get question by index"""
|
||||
return self.questions[index]
|
||||
|
||||
def get_question_str(self, question: Dict) -> str:
|
||||
def get_question_text(self, question: Dict) -> str:
|
||||
"""Get question string"""
|
||||
return question["problem"] if "problem" in question else question["question"]
|
||||
|
||||
@@ -682,7 +681,7 @@ class Gsm8kDataset(BaseDataset):
|
||||
def get_prompt(self, question: Dict) -> str:
|
||||
"""Get formatted prompt for the question"""
|
||||
return TEMPLATE_REGISTRY[question["dataset_type"]].format(
|
||||
question=self.get_question_str(question),
|
||||
question=self.get_question_text(question),
|
||||
)
|
||||
|
||||
class GpqaDataset(BaseDataset):
|
||||
@@ -737,7 +736,7 @@ class GpqaDataset(BaseDataset):
|
||||
"""Get question by index"""
|
||||
return self.questions[index]
|
||||
|
||||
def get_question_str(self, question: Dict) -> str:
|
||||
def get_question_text(self, question: Dict) -> str:
|
||||
"""Get question string"""
|
||||
return question["Question"]
|
||||
|
||||
@@ -748,7 +747,7 @@ class GpqaDataset(BaseDataset):
|
||||
def get_prompt(self, question: Dict) -> str:
|
||||
"""Get formatted prompt for the question"""
|
||||
return TEMPLATE_REGISTRY["gpqa"].format(
|
||||
Question=self.get_question_str(question),
|
||||
Question=self.get_question_text(question),
|
||||
A=question["shuffled_answers"][0],
|
||||
B=question["shuffled_answers"][1],
|
||||
C=question["shuffled_answers"][2],
|
||||
@@ -799,18 +798,18 @@ class Grader:
|
||||
for match in reversed(matches):
|
||||
if isinstance(match, tuple):
|
||||
match = match[0] if match[0] else match[1]
|
||||
extracted = match.strip()
|
||||
if extracted:
|
||||
return extracted
|
||||
answer = match.strip()
|
||||
if answer:
|
||||
return answer
|
||||
return None
|
||||
|
||||
def _grade_regex(self, gold: str, pred: str) -> Tuple[bool, Optional[str]]:
|
||||
"""Grade using regex pattern matching"""
|
||||
extracted = self._extract_answer_regex(pred)
|
||||
if extracted is None:
|
||||
answer = self._extract_answer_regex(pred)
|
||||
if answer is None:
|
||||
return False, None
|
||||
is_correct = extracted.strip() == gold.strip()
|
||||
return is_correct, extracted
|
||||
is_correct = answer.strip() == gold.strip()
|
||||
return is_correct, answer
|
||||
|
||||
def _grade_cli(self, gold: str, pred: str) -> Tuple[bool, Optional[str]]:
|
||||
"""Grade using external CLI script"""
|
||||
@@ -829,8 +828,8 @@ class Grader:
|
||||
timeout=30
|
||||
)
|
||||
is_correct = result.returncode == 0
|
||||
extracted = pred if is_correct else None
|
||||
return is_correct, extracted
|
||||
answer = pred if is_correct else None
|
||||
return is_correct, answer
|
||||
except subprocess.TimeoutExpired:
|
||||
return False, None
|
||||
except Exception as e:
|
||||
@@ -872,9 +871,9 @@ Please provide only the extracted answer, nothing else. If there is no clear ans
|
||||
try:
|
||||
response = requests.post(url, headers=headers, json=data)
|
||||
response.raise_for_status()
|
||||
extracted = response.json()["choices"][0]["message"]["content"].strip()
|
||||
is_correct = extracted.strip().lower() == gold.strip().lower()
|
||||
return is_correct, extracted
|
||||
answer = response.json()["choices"][0]["message"]["content"].strip()
|
||||
is_correct = answer.strip().lower() == gold.strip().lower()
|
||||
return is_correct, answer
|
||||
except Exception as e:
|
||||
return False, None
|
||||
|
||||
@@ -934,30 +933,31 @@ class Processor:
|
||||
return result, tokens, finish_reason
|
||||
|
||||
def _process_single_case(self, eval_state: EvalState, i: int, task_id: str) -> TaskState:
|
||||
question, prompt, gold = eval_state.get_case(i)
|
||||
question_text, prompt, expected = eval_state.get_case(i)
|
||||
|
||||
task_state = TaskState(
|
||||
case_id=task_id,
|
||||
task_id=task_id,
|
||||
prompt=prompt,
|
||||
gold=gold
|
||||
expected=expected,
|
||||
question_text=question_text
|
||||
)
|
||||
|
||||
try:
|
||||
response, tokens, finish_reason = self._make_request(eval_state, prompt)
|
||||
result = response["choices"][0]["message"]["content"]
|
||||
reasoning_content = response["choices"][0].get("message", {}).get("reasoning_content")
|
||||
task_state.result = result
|
||||
task_state.response = result
|
||||
task_state.tokens = tokens
|
||||
task_state.reasoning_content = reasoning_content
|
||||
|
||||
if finish_reason != "stop":
|
||||
task_state.status = f"error: finish_reason={finish_reason}"
|
||||
eval_state.add_result(task_id, prompt, gold, result, None, {"finish_reason": finish_reason}, False, task_state.status, tokens, reasoning_content)
|
||||
eval_state.add_result(task_id, prompt, expected, result, None, {"finish_reason": finish_reason}, False, task_state.status, tokens, reasoning_content)
|
||||
eval_state.dump()
|
||||
return task_state
|
||||
|
||||
result_truncated = self.grader._truncate_response(result, max_lines=10)
|
||||
is_correct, extracted = self.grader.grade(gold, result_truncated, prompt)
|
||||
is_correct, answer = self.grader.grade(expected, result_truncated, prompt)
|
||||
|
||||
grader_log = {
|
||||
"pred": result_truncated,
|
||||
@@ -967,11 +967,11 @@ class Processor:
|
||||
grader_log["pattern"] = self.grader.pattern
|
||||
|
||||
task_state.correct = is_correct
|
||||
task_state.extracted = extracted
|
||||
task_state.answer = answer
|
||||
task_state.grader_log = grader_log
|
||||
task_state.status = "ok"
|
||||
|
||||
eval_state.add_result(task_id, prompt, gold, result, extracted, grader_log, is_correct, "ok", tokens, reasoning_content)
|
||||
eval_state.add_result(task_id, prompt, expected, result, answer, grader_log, is_correct, "ok", tokens, reasoning_content)
|
||||
|
||||
eval_state.dump()
|
||||
|
||||
@@ -1009,11 +1009,11 @@ class Processor:
|
||||
|
||||
if verbose:
|
||||
print(f"\nCase {eval_state.processed}: {task_state.correct}")
|
||||
print(f" Gold: {task_state.gold}")
|
||||
if task_state.result:
|
||||
print(f" Result: {task_state.result}")
|
||||
if task_state.extracted:
|
||||
print(f" Extracted: {task_state.extracted}")
|
||||
print(f" Expected: {task_state.expected}")
|
||||
if task_state.response:
|
||||
print(f" Response: {task_state.response}")
|
||||
if task_state.answer:
|
||||
print(f" Answer: {task_state.answer}")
|
||||
print(f" Status: {task_state.status}")
|
||||
|
||||
eval_state.print_summary()
|
||||
|
||||
Reference in New Issue
Block a user