[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
This commit is contained in:
pre-commit-ci[bot] 2025-12-24 10:48:20 +00:00
parent ef9c0c3699
commit afab28dfa9
37 changed files with 4868 additions and 4052 deletions

View file

@ -29,6 +29,15 @@ from typing import Any, Dict, List, Optional, Tuple
import wandb
from datasets import load_dataset
from eval_helpers import (
build_mcqa_fallback_patterns,
create_system_content,
extract_letter_from_answer_tag,
extract_thinking_content,
get_default_thinking_prompt,
save_eval_results,
validate_thinking_format,
)
from pydantic import Field
from tqdm.asyncio import tqdm_asyncio
@ -38,15 +47,6 @@ from atroposlib.envs.base import (
BaseEnvConfig,
EvalHandlingEnum,
)
from eval_helpers import (
extract_letter_from_answer_tag,
validate_thinking_format,
extract_thinking_content,
get_default_thinking_prompt,
create_system_content,
save_eval_results,
build_mcqa_fallback_patterns,
)
class ARCEvalConfig(BaseEnvConfig):
@ -125,10 +125,10 @@ class ARCEvalConfig(BaseEnvConfig):
class ARCEvalEnv(BaseEnv):
"""
ARC Evaluation Environment for Atropos.
Evaluates models on grade-school science questions with multiple choice.
"""
name = "arc_eval"
env_config_cls = ARCEvalConfig
@ -142,7 +142,7 @@ class ARCEvalEnv(BaseEnv):
super().__init__(config, server_configs, slurm, testing)
self.config: ARCEvalConfig = config
self.eval_metrics = []
# Fallback patterns will be built after loading dataset (variable number of choices)
self._fallback_patterns = None
self._valid_letters = None
@ -179,8 +179,10 @@ class ARCEvalEnv(BaseEnv):
print(f" Evaluation split: {self.config.eval_split}")
print(f" Thinking mode: {self.config.thinking_mode}")
if self.config.thinking_mode:
print(f" Thinking prompt: {get_default_thinking_prompt(self.config.custom_thinking_prompt)[:80]}...")
print(
f" Thinking prompt: {get_default_thinking_prompt(self.config.custom_thinking_prompt)[:80]}..."
)
# Load dataset
self.dataset = load_dataset(
self.config.dataset_name,
@ -188,34 +190,34 @@ class ARCEvalEnv(BaseEnv):
split=self.config.eval_split,
trust_remote_code=True,
)
self.eval_items = list(self.dataset)
print(f" Loaded {len(self.eval_items)} evaluation items")
# Determine max number of choices (usually 4-5)
max_choices = max(len(item['choices']['text']) for item in self.eval_items)
max_choices = max(len(item["choices"]["text"]) for item in self.eval_items)
self._fallback_patterns = build_mcqa_fallback_patterns(max_choices)
self._valid_letters = set(ascii_uppercase[:max_choices])
def _format_prompt(self, item: Dict) -> Tuple[str, List[str]]:
"""
Format an ARC item into a prompt.
Returns the formatted prompt and list of choice texts.
"""
question = item['question']
choices_text = item['choices']['text']
choices_label = item['choices']['label']
question = item["question"]
choices_text = item["choices"]["text"]
choices_label = item["choices"]["label"]
# Build the question
query = "The following is a multiple choice science question.\n\n"
query += f"Question: {question}\n"
for label, text in zip(choices_label, choices_text):
query += f"{label}. {text}\n"
# Add answer instruction with <answer> tag format
query += "\nProvide your answer in <answer></answer> tags with only the letter of the correct choice."
return query, choices_text
def _create_system_content(self) -> Optional[str]:
    """
    Build the system-message content for this run.

    Delegates to the shared eval_helpers.create_system_content, combining
    the thinking-mode flag with any custom thinking/system prompts from
    the config. Returns None when no system message is needed.
    """
    return create_system_content(
        self.config.thinking_mode,
        self.config.custom_thinking_prompt,
        self.config.custom_system_prompt,
    )
def _extract_answer(self, response: str, item: Dict) -> Tuple[Optional[str], str]:
    """
    Extract the answer letter from the model's response.

    Primary path: <answer></answer> tags via extract_letter_from_answer_tag.
    Fallback: regex patterns from build_mcqa_fallback_patterns, sized to
    this question's number of choices.

    Args:
        response: Raw model completion text.
        item: Dataset row; item["choices"]["label"] defines which letters
            are accepted for this question.

    Returns:
        (letter, method): the extracted letter (or None when nothing
        matched) and the name of the extraction method used, e.g.
        "fallback_answer_colon" or "no_match".
    """
    # Only letters actually offered by this specific question are valid.
    valid_letters = set(item["choices"]["label"])
    choices = item["choices"]["text"]

    # In thinking mode, parse only the content after </think>.
    if self.config.thinking_mode:
        is_valid, content_after_think = validate_thinking_format(response, True)
        # NOTE(review): this inner branch was reconstructed from a truncated
        # diff — when the thinking format is valid, parse only the content
        # after </think>, otherwise fall back to the full response. Confirm
        # against the pre-diff source.
        if is_valid and content_after_think:
            response_to_parse = content_after_think
        else:
            response_to_parse = response
    else:
        response_to_parse = response

    # Primary: <answer></answer> tags.
    letter, method = extract_letter_from_answer_tag(
        response_to_parse,
        valid_letters,
        debug=self.config.full_debug,
        choices=choices,
    )
    if letter:
        return letter, method

    # Fallback: regex patterns built for this question's choice count.
    num_choices = len(item["choices"]["text"])
    fallback_patterns = build_mcqa_fallback_patterns(num_choices)

    for priority, pattern, method_name in fallback_patterns:
        matches = pattern.findall(response_to_parse)
        if matches:
            # For "final answer"-style phrasings the LAST occurrence wins
            # (models often restate the answer at the end); for everything
            # else take the first match.
            match = (
                matches[-1]
                if method_name
                in [
                    "final_answer_is",
                    "the_answer_is",
                    "answer_colon",
                    "answer_space",
                ]
                else matches[0]
            )
            # findall returns tuples when the pattern has multiple groups.
            if isinstance(match, tuple):
                match = match[0]
            letter = match.strip("()").upper()
            if letter in valid_letters:
                return letter, f"fallback_{method_name}"

    return None, "no_match"
async def _generate_with_retry(self, messages: List[Dict], item_id: str) -> Optional[str]:
async def _generate_with_retry(
self, messages: List[Dict], item_id: str
) -> Optional[str]:
"""Generate response with retry logic."""
for attempt in range(self.config.max_retries):
try:
@ -283,70 +297,70 @@ class ARCEvalEnv(BaseEnv):
}
if self.config.eval_max_tokens > 0:
api_params["max_tokens"] = self.config.eval_max_tokens
response = await self.client.chat.completions.create(**api_params)
if response.choices and response.choices[0].message.content:
content = response.choices[0].message.content.strip()
if len(content) >= self.config.min_response_length:
return content
except Exception as e:
if self.config.full_debug:
print(f" Error on item {item_id} attempt {attempt + 1}: {e}")
if attempt < self.config.max_retries - 1:
await asyncio.sleep(self.config.retry_delay * (attempt + 1))
return None
async def _evaluate_single_item(self, item: Dict, idx: int) -> Dict:
    """
    Evaluate a single ARC item end-to-end.

    Formats the prompt, queries the model (with retries), extracts the
    predicted letter, and scores it against item["answerKey"].

    Args:
        item: Dataset row (question, choices, answerKey, optional id).
        idx: Position of the item in the eval set; also used as the
            item id for retry logging.

    Returns:
        Result dict with index, question_id, is_correct, extracted_answer,
        gold_answer, and extraction_method; includes "error" when
        generation failed, and response/prompt when full_debug is on.
    """
    # Format prompt
    prompt, choices = self._format_prompt(item)

    # Build messages (optional system message + user prompt)
    messages = []
    system_content = self._create_system_content()
    if system_content:
        messages.append({"role": "system", "content": system_content})
    messages.append({"role": "user", "content": prompt})

    # Generate response; None means all retries were exhausted.
    response = await self._generate_with_retry(messages, str(idx))

    if response is None:
        return {
            "index": idx,
            "question_id": item.get("id", ""),
            "is_correct": False,
            "extracted_answer": None,
            "gold_answer": item["answerKey"],
            "extraction_method": "generation_failed",
            "error": "Failed to generate response",
        }

    # Extract answer
    extracted_answer, extraction_method = self._extract_answer(response, item)

    # Gold answer
    gold_answer = item["answerKey"]

    # Score: only an exact letter match counts; no extraction -> incorrect.
    is_correct = extracted_answer == gold_answer if extracted_answer else False

    result = {
        "index": idx,
        "question_id": item.get("id", ""),
        "is_correct": is_correct,
        "extracted_answer": extracted_answer,
        "gold_answer": gold_answer,
        "extraction_method": extraction_method,
    }

    # Keep full transcripts only in debug mode to limit result size.
    if self.config.full_debug:
        result["response"] = response
        result["prompt"] = prompt

    return result
async def evaluate(self, *args, **kwargs):
@ -358,32 +372,34 @@ class ARCEvalEnv(BaseEnv):
print(f" Total questions: {len(self.eval_items)}")
print(f" Thinking mode: {self.config.thinking_mode}")
print("=" * 60)
# Evaluate all items
tasks = [
self._evaluate_single_item(item, idx)
for idx, item in enumerate(self.eval_items)
]
results = await tqdm_asyncio.gather(*tasks, desc=f"Evaluating {self.config.subset}")
results = await tqdm_asyncio.gather(
*tasks, desc=f"Evaluating {self.config.subset}"
)
# Calculate metrics
valid_results = [r for r in results if r.get("gold_answer") is not None]
if not valid_results:
print("Warning: No valid evaluation results obtained")
return
correct = sum(1 for r in valid_results if r["is_correct"])
total = len(valid_results)
accuracy = correct / total if total > 0 else 0.0
# Extraction method breakdown
method_counts = {}
for r in valid_results:
method = r.get("extraction_method", "unknown")
method_counts[method] = method_counts.get(method, 0) + 1
# Print summary
print("\n" + "=" * 60)
print(f"ARC {self.config.subset} Evaluation Results")
@ -396,7 +412,7 @@ class ARCEvalEnv(BaseEnv):
for method, count in sorted(method_counts.items(), key=lambda x: -x[1]):
print(f" {method}: {count} ({count/total:.1%})")
print("=" * 60)
# Save results
metrics = {
"accuracy": accuracy,
@ -405,18 +421,16 @@ class ARCEvalEnv(BaseEnv):
"subset": self.config.subset,
"extraction_methods": method_counts,
}
save_eval_results(
self.config.data_dir_to_save_evals,
metrics,
results
)
self.eval_metrics = [{
"accuracy": accuracy,
"total": total,
"subset": self.config.subset,
}]
save_eval_results(self.config.data_dir_to_save_evals, metrics, results)
self.eval_metrics = [
{
"accuracy": accuracy,
"total": total,
"subset": self.config.subset,
}
]
async def wandb_log(self, step: int):
"""Log metrics to wandb."""
@ -437,4 +451,3 @@ class ARCEvalEnv(BaseEnv):
# Script entry point: launch the Atropos environment CLI for this env.
if __name__ == "__main__":
    ARCEvalEnv.cli()