mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-19 12:57:58 +00:00
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
This commit is contained in:
parent
ef9c0c3699
commit
afab28dfa9
37 changed files with 4868 additions and 4052 deletions
|
|
@ -28,6 +28,15 @@ from typing import Any, Dict, List, Optional, Tuple
|
|||
|
||||
import wandb
|
||||
from datasets import load_dataset
|
||||
from eval_helpers import (
|
||||
build_mcqa_fallback_patterns,
|
||||
create_system_content,
|
||||
extract_letter_from_answer_tag,
|
||||
extract_thinking_content,
|
||||
get_default_thinking_prompt,
|
||||
save_eval_results,
|
||||
validate_thinking_format,
|
||||
)
|
||||
from pydantic import Field
|
||||
from tqdm.asyncio import tqdm_asyncio
|
||||
|
||||
|
|
@ -37,15 +46,6 @@ from atroposlib.envs.base import (
|
|||
BaseEnvConfig,
|
||||
EvalHandlingEnum,
|
||||
)
|
||||
from eval_helpers import (
|
||||
extract_letter_from_answer_tag,
|
||||
validate_thinking_format,
|
||||
extract_thinking_content,
|
||||
get_default_thinking_prompt,
|
||||
create_system_content,
|
||||
save_eval_results,
|
||||
build_mcqa_fallback_patterns,
|
||||
)
|
||||
|
||||
|
||||
class WinoGrandeEvalConfig(BaseEnvConfig):
|
||||
|
|
@ -124,10 +124,10 @@ class WinoGrandeEvalConfig(BaseEnvConfig):
|
|||
class WinoGrandeEvalEnv(BaseEnv):
|
||||
"""
|
||||
WinoGrande Evaluation Environment for Atropos.
|
||||
|
||||
|
||||
Evaluates models on commonsense reasoning with binary choice pronoun resolution.
|
||||
"""
|
||||
|
||||
|
||||
name = "winogrande_eval"
|
||||
env_config_cls = WinoGrandeEvalConfig
|
||||
|
||||
|
|
@ -141,10 +141,10 @@ class WinoGrandeEvalEnv(BaseEnv):
|
|||
super().__init__(config, server_configs, slurm, testing)
|
||||
self.config: WinoGrandeEvalConfig = config
|
||||
self.eval_metrics = []
|
||||
|
||||
|
||||
# Pre-build fallback patterns for 2-choice (A/B)
|
||||
self._fallback_patterns = build_mcqa_fallback_patterns(2)
|
||||
self._valid_letters = {'A', 'B'}
|
||||
self._valid_letters = {"A", "B"}
|
||||
|
||||
@classmethod
|
||||
def config_init(cls) -> Tuple[WinoGrandeEvalConfig, List[APIServerConfig]]:
|
||||
|
|
@ -178,8 +178,10 @@ class WinoGrandeEvalEnv(BaseEnv):
|
|||
print(f" Evaluation split: {self.config.eval_split}")
|
||||
print(f" Thinking mode: {self.config.thinking_mode}")
|
||||
if self.config.thinking_mode:
|
||||
print(f" Thinking prompt: {get_default_thinking_prompt(self.config.custom_thinking_prompt)[:80]}...")
|
||||
|
||||
print(
|
||||
f" Thinking prompt: {get_default_thinking_prompt(self.config.custom_thinking_prompt)[:80]}..."
|
||||
)
|
||||
|
||||
# Load dataset
|
||||
self.dataset = load_dataset(
|
||||
self.config.dataset_name,
|
||||
|
|
@ -187,22 +189,22 @@ class WinoGrandeEvalEnv(BaseEnv):
|
|||
split=self.config.eval_split,
|
||||
trust_remote_code=True,
|
||||
)
|
||||
|
||||
|
||||
self.eval_items = list(self.dataset)
|
||||
print(f" Loaded {len(self.eval_items)} evaluation items")
|
||||
|
||||
def _format_prompt(self, item: Dict) -> Tuple[str, List[str]]:
|
||||
"""
|
||||
Format a WinoGrande item into a prompt.
|
||||
|
||||
|
||||
The sentence contains an underscore "_" that should be filled with one of two options.
|
||||
|
||||
|
||||
Returns the formatted prompt and list of choice texts.
|
||||
"""
|
||||
sentence = item['sentence']
|
||||
option1 = item['option1']
|
||||
option2 = item['option2']
|
||||
|
||||
sentence = item["sentence"]
|
||||
option1 = item["option1"]
|
||||
option2 = item["option2"]
|
||||
|
||||
# Build the question
|
||||
query = "The following is a fill-in-the-blank question about commonsense reasoning.\n\n"
|
||||
query += f"Sentence: {sentence}\n\n"
|
||||
|
|
@ -210,7 +212,7 @@ class WinoGrandeEvalEnv(BaseEnv):
|
|||
query += f"A. {option1}\n"
|
||||
query += f"B. {option2}\n"
|
||||
query += "\nProvide your answer in <answer></answer> tags with only the letter (A or B)."
|
||||
|
||||
|
||||
return query, [option1, option2]
|
||||
|
||||
def _create_system_content(self) -> Optional[str]:
|
||||
|
|
@ -218,13 +220,15 @@ class WinoGrandeEvalEnv(BaseEnv):
|
|||
return create_system_content(
|
||||
self.config.thinking_mode,
|
||||
self.config.custom_thinking_prompt,
|
||||
self.config.custom_system_prompt
|
||||
self.config.custom_system_prompt,
|
||||
)
|
||||
|
||||
def _extract_answer(self, response: str, choices: List[str] = None) -> Tuple[Optional[str], str]:
|
||||
def _extract_answer(
|
||||
self, response: str, choices: List[str] = None
|
||||
) -> Tuple[Optional[str], str]:
|
||||
"""
|
||||
Extract the answer letter from the model's response.
|
||||
|
||||
|
||||
Uses <answer> tags as primary method, with fallback patterns.
|
||||
"""
|
||||
# Get content after </think> if in thinking mode
|
||||
|
|
@ -236,31 +240,43 @@ class WinoGrandeEvalEnv(BaseEnv):
|
|||
response_to_parse = response
|
||||
else:
|
||||
response_to_parse = response
|
||||
|
||||
|
||||
# Primary: Try <answer></answer> tags
|
||||
letter, method = extract_letter_from_answer_tag(
|
||||
response_to_parse,
|
||||
self._valid_letters,
|
||||
debug=self.config.full_debug,
|
||||
choices=choices
|
||||
choices=choices,
|
||||
)
|
||||
if letter:
|
||||
return letter, method
|
||||
|
||||
|
||||
# Fallback: Use regex patterns
|
||||
for priority, pattern, method_name in self._fallback_patterns:
|
||||
matches = pattern.findall(response_to_parse)
|
||||
if matches:
|
||||
match = matches[-1] if method_name in ["final_answer_is", "the_answer_is", "answer_colon", "answer_space"] else matches[0]
|
||||
match = (
|
||||
matches[-1]
|
||||
if method_name
|
||||
in [
|
||||
"final_answer_is",
|
||||
"the_answer_is",
|
||||
"answer_colon",
|
||||
"answer_space",
|
||||
]
|
||||
else matches[0]
|
||||
)
|
||||
if isinstance(match, tuple):
|
||||
match = match[0]
|
||||
letter = match.strip("()").upper()
|
||||
if letter in self._valid_letters:
|
||||
return letter, f"fallback_{method_name}"
|
||||
|
||||
|
||||
return None, "no_match"
|
||||
|
||||
async def _generate_with_retry(self, messages: List[Dict], item_id: str) -> Optional[str]:
|
||||
async def _generate_with_retry(
|
||||
self, messages: List[Dict], item_id: str
|
||||
) -> Optional[str]:
|
||||
"""Generate response with retry logic."""
|
||||
for attempt in range(self.config.max_retries):
|
||||
try:
|
||||
|
|
@ -271,59 +287,67 @@ class WinoGrandeEvalEnv(BaseEnv):
|
|||
}
|
||||
if self.config.eval_max_tokens > 0:
|
||||
api_params["max_tokens"] = self.config.eval_max_tokens
|
||||
|
||||
|
||||
response = await self.client.chat.completions.create(**api_params)
|
||||
|
||||
|
||||
if response.choices and response.choices[0].message.content:
|
||||
content = response.choices[0].message.content.strip()
|
||||
if len(content) >= self.config.min_response_length:
|
||||
return content
|
||||
|
||||
|
||||
except Exception as e:
|
||||
if self.config.full_debug:
|
||||
print(f" Error on item {item_id} attempt {attempt + 1}: {e}")
|
||||
if attempt < self.config.max_retries - 1:
|
||||
await asyncio.sleep(self.config.retry_delay * (attempt + 1))
|
||||
|
||||
|
||||
return None
|
||||
|
||||
async def _evaluate_single_item(self, item: Dict, idx: int) -> Dict:
|
||||
"""Evaluate a single WinoGrande item."""
|
||||
# Format prompt
|
||||
prompt, choices = self._format_prompt(item)
|
||||
|
||||
|
||||
# Build messages
|
||||
messages = []
|
||||
system_content = self._create_system_content()
|
||||
if system_content:
|
||||
messages.append({"role": "system", "content": system_content})
|
||||
messages.append({"role": "user", "content": prompt})
|
||||
|
||||
|
||||
# Generate response
|
||||
response = await self._generate_with_retry(messages, str(idx))
|
||||
|
||||
|
||||
if response is None:
|
||||
return {
|
||||
"index": idx,
|
||||
"is_correct": False,
|
||||
"extracted_answer": None,
|
||||
"gold_answer": ascii_uppercase[int(item['answer']) - 1] if item['answer'] != '' else None,
|
||||
"gold_answer": (
|
||||
ascii_uppercase[int(item["answer"]) - 1]
|
||||
if item["answer"] != ""
|
||||
else None
|
||||
),
|
||||
"extraction_method": "generation_failed",
|
||||
"error": "Failed to generate response",
|
||||
}
|
||||
|
||||
|
||||
# Extract answer
|
||||
extracted_answer, extraction_method = self._extract_answer(response, choices)
|
||||
|
||||
|
||||
# Determine gold answer (WinoGrande uses 1/2, we convert to A/B)
|
||||
gold_answer = None
|
||||
if item['answer'] != '':
|
||||
gold_idx = int(item['answer']) - 1 # Convert 1-indexed to 0-indexed
|
||||
if item["answer"] != "":
|
||||
gold_idx = int(item["answer"]) - 1 # Convert 1-indexed to 0-indexed
|
||||
gold_answer = ascii_uppercase[gold_idx]
|
||||
|
||||
|
||||
# Score
|
||||
is_correct = extracted_answer == gold_answer if extracted_answer and gold_answer else False
|
||||
|
||||
is_correct = (
|
||||
extracted_answer == gold_answer
|
||||
if extracted_answer and gold_answer
|
||||
else False
|
||||
)
|
||||
|
||||
result = {
|
||||
"index": idx,
|
||||
"is_correct": is_correct,
|
||||
|
|
@ -331,11 +355,11 @@ class WinoGrandeEvalEnv(BaseEnv):
|
|||
"gold_answer": gold_answer,
|
||||
"extraction_method": extraction_method,
|
||||
}
|
||||
|
||||
|
||||
if self.config.full_debug:
|
||||
result["response"] = response
|
||||
result["prompt"] = prompt
|
||||
|
||||
|
||||
return result
|
||||
|
||||
async def evaluate(self, *args, **kwargs):
|
||||
|
|
@ -347,32 +371,32 @@ class WinoGrandeEvalEnv(BaseEnv):
|
|||
print(f" Total questions: {len(self.eval_items)}")
|
||||
print(f" Thinking mode: {self.config.thinking_mode}")
|
||||
print("=" * 60)
|
||||
|
||||
|
||||
# Evaluate all items
|
||||
tasks = [
|
||||
self._evaluate_single_item(item, idx)
|
||||
for idx, item in enumerate(self.eval_items)
|
||||
]
|
||||
|
||||
|
||||
results = await tqdm_asyncio.gather(*tasks, desc="Evaluating WinoGrande")
|
||||
|
||||
|
||||
# Calculate metrics
|
||||
valid_results = [r for r in results if r.get("gold_answer") is not None]
|
||||
|
||||
|
||||
if not valid_results:
|
||||
print("Warning: No valid evaluation results obtained")
|
||||
return
|
||||
|
||||
|
||||
correct = sum(1 for r in valid_results if r["is_correct"])
|
||||
total = len(valid_results)
|
||||
accuracy = correct / total if total > 0 else 0.0
|
||||
|
||||
|
||||
# Extraction method breakdown
|
||||
method_counts = {}
|
||||
for r in valid_results:
|
||||
method = r.get("extraction_method", "unknown")
|
||||
method_counts[method] = method_counts.get(method, 0) + 1
|
||||
|
||||
|
||||
# Print summary
|
||||
print("\n" + "=" * 60)
|
||||
print("WinoGrande Evaluation Results")
|
||||
|
|
@ -386,7 +410,7 @@ class WinoGrandeEvalEnv(BaseEnv):
|
|||
for method, count in sorted(method_counts.items(), key=lambda x: -x[1]):
|
||||
print(f" {method}: {count} ({count/total:.1%})")
|
||||
print("=" * 60)
|
||||
|
||||
|
||||
# Save results
|
||||
metrics = {
|
||||
"accuracy": accuracy,
|
||||
|
|
@ -395,17 +419,15 @@ class WinoGrandeEvalEnv(BaseEnv):
|
|||
"subset": self.config.subset,
|
||||
"extraction_methods": method_counts,
|
||||
}
|
||||
|
||||
save_eval_results(
|
||||
self.config.data_dir_to_save_evals,
|
||||
metrics,
|
||||
results
|
||||
)
|
||||
|
||||
self.eval_metrics = [{
|
||||
"accuracy": accuracy,
|
||||
"total": total,
|
||||
}]
|
||||
|
||||
save_eval_results(self.config.data_dir_to_save_evals, metrics, results)
|
||||
|
||||
self.eval_metrics = [
|
||||
{
|
||||
"accuracy": accuracy,
|
||||
"total": total,
|
||||
}
|
||||
]
|
||||
|
||||
async def wandb_log(self, step: int):
|
||||
"""Log metrics to wandb."""
|
||||
|
|
@ -426,4 +448,3 @@ class WinoGrandeEvalEnv(BaseEnv):
|
|||
|
||||
if __name__ == "__main__":
|
||||
WinoGrandeEvalEnv.cli()
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue