mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-19 12:57:58 +00:00
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
This commit is contained in:
parent
ef9c0c3699
commit
afab28dfa9
37 changed files with 4868 additions and 4052 deletions
|
|
@ -24,8 +24,8 @@ Supports optional thinking mode with <think></think> tags.
|
|||
Answer must be provided in <answer></answer> tags as a JSON 2D array.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import ast
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
|
|
@ -34,6 +34,14 @@ from typing import Any, Dict, List, Optional, Tuple
|
|||
|
||||
import wandb
|
||||
from datasets import load_dataset
|
||||
from eval_helpers import (
|
||||
ANSWER_TAG_PATTERN,
|
||||
create_system_content,
|
||||
extract_thinking_content,
|
||||
get_default_thinking_prompt,
|
||||
save_eval_results,
|
||||
validate_thinking_format,
|
||||
)
|
||||
from pydantic import Field
|
||||
from tqdm.asyncio import tqdm_asyncio
|
||||
|
||||
|
|
@ -43,14 +51,6 @@ from atroposlib.envs.base import (
|
|||
BaseEnvConfig,
|
||||
EvalHandlingEnum,
|
||||
)
|
||||
from eval_helpers import (
|
||||
validate_thinking_format,
|
||||
extract_thinking_content,
|
||||
get_default_thinking_prompt,
|
||||
create_system_content,
|
||||
save_eval_results,
|
||||
ANSWER_TAG_PATTERN,
|
||||
)
|
||||
|
||||
|
||||
class ARCAGIEvalConfig(BaseEnvConfig):
|
||||
|
|
@ -124,10 +124,10 @@ class ARCAGIEvalConfig(BaseEnvConfig):
|
|||
class ARCAGIEvalEnv(BaseEnv):
|
||||
"""
|
||||
ARC-AGI 2 Evaluation Environment for Atropos.
|
||||
|
||||
|
||||
Evaluates models on abstract reasoning with grid-based pattern puzzles.
|
||||
"""
|
||||
|
||||
|
||||
name = "arc_agi_eval"
|
||||
env_config_cls = ARCAGIEvalConfig
|
||||
|
||||
|
|
@ -173,15 +173,17 @@ class ARCAGIEvalEnv(BaseEnv):
|
|||
print(f" Evaluation split: {self.config.eval_split}")
|
||||
print(f" Thinking mode: {self.config.thinking_mode}")
|
||||
if self.config.thinking_mode:
|
||||
print(f" Thinking prompt: {get_default_thinking_prompt(self.config.custom_thinking_prompt)[:80]}...")
|
||||
|
||||
print(
|
||||
f" Thinking prompt: {get_default_thinking_prompt(self.config.custom_thinking_prompt)[:80]}..."
|
||||
)
|
||||
|
||||
# Load dataset
|
||||
self.dataset = load_dataset(
|
||||
self.config.dataset_name,
|
||||
split=self.config.eval_split,
|
||||
trust_remote_code=True,
|
||||
)
|
||||
|
||||
|
||||
self.eval_items = list(self.dataset)
|
||||
print(f" Loaded {len(self.eval_items)} evaluation items")
|
||||
|
||||
|
|
@ -195,24 +197,24 @@ class ARCAGIEvalEnv(BaseEnv):
|
|||
def _format_prompt(self, item: Dict) -> Tuple[str, List[List[int]]]:
|
||||
"""
|
||||
Format an ARC-AGI 2 item into a prompt.
|
||||
|
||||
|
||||
Returns the formatted prompt and the gold answer grid.
|
||||
"""
|
||||
# Build training examples
|
||||
training_pairs = item["fewshots"]
|
||||
training_examples = ""
|
||||
|
||||
|
||||
for i, pair in enumerate(training_pairs):
|
||||
training_examples += f"--Example {i + 1}--\n\n"
|
||||
training_examples += "INPUT:\n"
|
||||
training_examples += self._grid_to_string(pair["input"]) + "\n\n"
|
||||
training_examples += "OUTPUT:\n"
|
||||
training_examples += self._grid_to_string(pair["output"]) + "\n\n"
|
||||
|
||||
|
||||
# Test input
|
||||
test_input = self._grid_to_string(item["question"][0]["input"])
|
||||
gold_output = item["question"][0]["output"]
|
||||
|
||||
|
||||
# Build the prompt
|
||||
query = """You are solving an ARC-AGI puzzle. You will be shown training examples where an input grid is transformed into an output grid following a specific pattern or rule.
|
||||
|
||||
|
|
@ -242,8 +244,10 @@ Example format:
|
|||
[3, 4, 5],
|
||||
[6, 7, 8]]
|
||||
</answer>
|
||||
""".format(training_examples=training_examples, test_input=test_input)
|
||||
|
||||
""".format(
|
||||
training_examples=training_examples, test_input=test_input
|
||||
)
|
||||
|
||||
return query, gold_output
|
||||
|
||||
def _create_system_content(self) -> Optional[str]:
|
||||
|
|
@ -251,7 +255,7 @@ Example format:
|
|||
return create_system_content(
|
||||
self.config.thinking_mode,
|
||||
self.config.custom_thinking_prompt,
|
||||
self.config.custom_system_prompt
|
||||
self.config.custom_system_prompt,
|
||||
)
|
||||
|
||||
def _is_valid_grid(self, grid: Any) -> bool:
|
||||
|
|
@ -274,7 +278,7 @@ Example format:
|
|||
def _parse_grid_from_string(self, text: str) -> Optional[List[List[int]]]:
|
||||
"""
|
||||
Parse a 2D grid from a string.
|
||||
|
||||
|
||||
Tries multiple parsing strategies:
|
||||
1. Direct JSON parse of the whole text
|
||||
2. ast.literal_eval (handles Python list syntax)
|
||||
|
|
@ -282,9 +286,9 @@ Example format:
|
|||
"""
|
||||
if not text or not text.strip():
|
||||
return None
|
||||
|
||||
|
||||
text = text.strip()
|
||||
|
||||
|
||||
# Strategy 1: Direct JSON parse
|
||||
try:
|
||||
grid = json.loads(text)
|
||||
|
|
@ -292,7 +296,7 @@ Example format:
|
|||
return grid
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
|
||||
# Strategy 2: ast.literal_eval (handles Python repr format)
|
||||
try:
|
||||
grid = ast.literal_eval(text)
|
||||
|
|
@ -300,12 +304,12 @@ Example format:
|
|||
return grid
|
||||
except (ValueError, SyntaxError):
|
||||
pass
|
||||
|
||||
|
||||
# Strategy 3: Find the nested array pattern
|
||||
# Look for [[...], [...], ...]
|
||||
nested_pattern = r'\[\s*\[[\d,\s\[\]]+\]\s*\]'
|
||||
nested_pattern = r"\[\s*\[[\d,\s\[\]]+\]\s*\]"
|
||||
matches = re.findall(nested_pattern, text, re.DOTALL)
|
||||
|
||||
|
||||
for match in matches:
|
||||
try:
|
||||
grid = ast.literal_eval(match)
|
||||
|
|
@ -313,12 +317,12 @@ Example format:
|
|||
return grid
|
||||
except:
|
||||
continue
|
||||
|
||||
|
||||
# Strategy 4: Extract rows one per line
|
||||
# Look for lines like [0, 1, 2, 3]
|
||||
row_pattern = r'\[\s*\d+(?:\s*,\s*\d+)*\s*\]'
|
||||
row_pattern = r"\[\s*\d+(?:\s*,\s*\d+)*\s*\]"
|
||||
rows = re.findall(row_pattern, text)
|
||||
|
||||
|
||||
if rows:
|
||||
try:
|
||||
grid = [json.loads(row) for row in rows]
|
||||
|
|
@ -326,13 +330,13 @@ Example format:
|
|||
return grid
|
||||
except:
|
||||
pass
|
||||
|
||||
|
||||
return None
|
||||
|
||||
def _extract_answer(self, response: str) -> Tuple[Optional[List[List[int]]], str]:
|
||||
"""
|
||||
Extract the grid answer from the model's response.
|
||||
|
||||
|
||||
Looks for content inside <answer></answer> tags after </think> (if thinking mode).
|
||||
"""
|
||||
# Get content after </think> if in thinking mode
|
||||
|
|
@ -344,7 +348,7 @@ Example format:
|
|||
response_to_parse = response
|
||||
else:
|
||||
response_to_parse = response
|
||||
|
||||
|
||||
# Try <answer></answer> tags first
|
||||
answer_match = ANSWER_TAG_PATTERN.search(response_to_parse)
|
||||
if answer_match:
|
||||
|
|
@ -354,17 +358,21 @@ Example format:
|
|||
return grid, "answer_tag"
|
||||
else:
|
||||
if self.config.full_debug:
|
||||
print(f" Found answer tag but couldn't parse grid: {answer_content[:100]}...")
|
||||
print(
|
||||
f" Found answer tag but couldn't parse grid: {answer_content[:100]}..."
|
||||
)
|
||||
return None, "answer_tag_parse_failed"
|
||||
|
||||
|
||||
# Fallback: Try to find grid anywhere in response
|
||||
grid = self._parse_grid_from_string(response_to_parse)
|
||||
if grid:
|
||||
return grid, "fallback_grid_search"
|
||||
|
||||
|
||||
return None, "no_match"
|
||||
|
||||
def _grids_match(self, pred_grid: List[List[int]], gold_grid: List[List[int]]) -> bool:
|
||||
def _grids_match(
|
||||
self, pred_grid: List[List[int]], gold_grid: List[List[int]]
|
||||
) -> bool:
|
||||
"""Check if two grids are pixel-perfect matches."""
|
||||
if pred_grid is None or gold_grid is None:
|
||||
return False
|
||||
|
|
@ -377,7 +385,9 @@ Example format:
|
|||
return False
|
||||
return True
|
||||
|
||||
async def _generate_with_retry(self, messages: List[Dict], item_id: str) -> Optional[str]:
|
||||
async def _generate_with_retry(
|
||||
self, messages: List[Dict], item_id: str
|
||||
) -> Optional[str]:
|
||||
"""Generate response with retry logic."""
|
||||
for attempt in range(self.config.max_retries):
|
||||
try:
|
||||
|
|
@ -388,37 +398,37 @@ Example format:
|
|||
}
|
||||
if self.config.eval_max_tokens > 0:
|
||||
api_params["max_tokens"] = self.config.eval_max_tokens
|
||||
|
||||
|
||||
response = await self.client.chat.completions.create(**api_params)
|
||||
|
||||
|
||||
if response.choices and response.choices[0].message.content:
|
||||
content = response.choices[0].message.content.strip()
|
||||
if len(content) >= self.config.min_response_length:
|
||||
return content
|
||||
|
||||
|
||||
except Exception as e:
|
||||
if self.config.full_debug:
|
||||
print(f" Error on item {item_id} attempt {attempt + 1}: {e}")
|
||||
if attempt < self.config.max_retries - 1:
|
||||
await asyncio.sleep(self.config.retry_delay * (attempt + 1))
|
||||
|
||||
|
||||
return None
|
||||
|
||||
async def _evaluate_single_item(self, item: Dict, idx: int) -> Dict:
|
||||
"""Evaluate a single ARC-AGI 2 item."""
|
||||
# Format prompt
|
||||
prompt, gold_grid = self._format_prompt(item)
|
||||
|
||||
|
||||
# Build messages
|
||||
messages = []
|
||||
system_content = self._create_system_content()
|
||||
if system_content:
|
||||
messages.append({"role": "system", "content": system_content})
|
||||
messages.append({"role": "user", "content": prompt})
|
||||
|
||||
|
||||
# Generate response
|
||||
response = await self._generate_with_retry(messages, str(idx))
|
||||
|
||||
|
||||
if response is None:
|
||||
return {
|
||||
"index": idx,
|
||||
|
|
@ -428,13 +438,13 @@ Example format:
|
|||
"extraction_method": "generation_failed",
|
||||
"error": "Failed to generate response",
|
||||
}
|
||||
|
||||
|
||||
# Extract answer
|
||||
extracted_grid, extraction_method = self._extract_answer(response)
|
||||
|
||||
|
||||
# Score - pixel perfect match
|
||||
is_correct = self._grids_match(extracted_grid, gold_grid)
|
||||
|
||||
|
||||
result = {
|
||||
"index": idx,
|
||||
"is_correct": is_correct,
|
||||
|
|
@ -443,13 +453,15 @@ Example format:
|
|||
"extraction_method": extraction_method,
|
||||
"num_training_examples": len(item["fewshots"]),
|
||||
"input_grid_size": f"{len(item['question'][0]['input'])}x{len(item['question'][0]['input'][0])}",
|
||||
"output_grid_size": f"{len(gold_grid)}x{len(gold_grid[0])}" if gold_grid else "unknown",
|
||||
"output_grid_size": (
|
||||
f"{len(gold_grid)}x{len(gold_grid[0])}" if gold_grid else "unknown"
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
if self.config.full_debug:
|
||||
result["response"] = response
|
||||
result["prompt"] = prompt
|
||||
|
||||
|
||||
return result
|
||||
|
||||
async def evaluate(self, *args, **kwargs):
|
||||
|
|
@ -460,34 +472,36 @@ Example format:
|
|||
print(f" Total puzzles: {len(self.eval_items)}")
|
||||
print(f" Thinking mode: {self.config.thinking_mode}")
|
||||
print("=" * 60)
|
||||
|
||||
|
||||
# Evaluate all items
|
||||
tasks = [
|
||||
self._evaluate_single_item(item, idx)
|
||||
for idx, item in enumerate(self.eval_items)
|
||||
]
|
||||
|
||||
|
||||
results = await tqdm_asyncio.gather(*tasks, desc="Evaluating ARC-AGI 2")
|
||||
|
||||
|
||||
# Calculate metrics
|
||||
total = len(results)
|
||||
|
||||
|
||||
if total == 0:
|
||||
print("Warning: No evaluation results obtained")
|
||||
return
|
||||
|
||||
|
||||
correct = sum(1 for r in results if r["is_correct"])
|
||||
accuracy = correct / total if total > 0 else 0.0
|
||||
|
||||
|
||||
# Extraction method breakdown
|
||||
method_counts = {}
|
||||
for r in results:
|
||||
method = r.get("extraction_method", "unknown")
|
||||
method_counts[method] = method_counts.get(method, 0) + 1
|
||||
|
||||
|
||||
# Grid size stats
|
||||
successful_extractions = sum(1 for r in results if r["extracted_grid"] is not None)
|
||||
|
||||
successful_extractions = sum(
|
||||
1 for r in results if r["extracted_grid"] is not None
|
||||
)
|
||||
|
||||
# Print summary
|
||||
print("\n" + "=" * 60)
|
||||
print("ARC-AGI 2 Evaluation Results")
|
||||
|
|
@ -496,13 +510,15 @@ Example format:
|
|||
print(f" Correct (pixel-perfect): {correct}")
|
||||
print(f" Accuracy: {accuracy:.2%}")
|
||||
print("-" * 60)
|
||||
print(f" Successful grid extractions: {successful_extractions}/{total} ({successful_extractions/total:.1%})")
|
||||
print(
|
||||
f" Successful grid extractions: {successful_extractions}/{total} ({successful_extractions/total:.1%})"
|
||||
)
|
||||
print("-" * 60)
|
||||
print(" Extraction Methods:")
|
||||
for method, count in sorted(method_counts.items(), key=lambda x: -x[1]):
|
||||
print(f" {method}: {count} ({count/total:.1%})")
|
||||
print("=" * 60)
|
||||
|
||||
|
||||
# Save results
|
||||
metrics = {
|
||||
"accuracy": accuracy,
|
||||
|
|
@ -512,18 +528,16 @@ Example format:
|
|||
"extraction_rate": successful_extractions / total if total > 0 else 0.0,
|
||||
"extraction_methods": method_counts,
|
||||
}
|
||||
|
||||
save_eval_results(
|
||||
self.config.data_dir_to_save_evals,
|
||||
metrics,
|
||||
results
|
||||
)
|
||||
|
||||
self.eval_metrics = [{
|
||||
"accuracy": accuracy,
|
||||
"total": total,
|
||||
"extraction_rate": successful_extractions / total if total > 0 else 0.0,
|
||||
}]
|
||||
|
||||
save_eval_results(self.config.data_dir_to_save_evals, metrics, results)
|
||||
|
||||
self.eval_metrics = [
|
||||
{
|
||||
"accuracy": accuracy,
|
||||
"total": total,
|
||||
"extraction_rate": successful_extractions / total if total > 0 else 0.0,
|
||||
}
|
||||
]
|
||||
|
||||
async def wandb_log(self, step: int):
|
||||
"""Log metrics to wandb."""
|
||||
|
|
@ -544,4 +558,3 @@ Example format:
|
|||
|
||||
if __name__ == "__main__":
|
||||
ARCAGIEvalEnv.cli()
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue