[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
This commit is contained in:
pre-commit-ci[bot] 2025-12-24 10:48:20 +00:00
parent ef9c0c3699
commit afab28dfa9
37 changed files with 4868 additions and 4052 deletions

View file

@ -29,6 +29,14 @@ from typing import Any, Dict, List, Optional, Tuple
import wandb
from datasets import load_dataset
from eval_helpers import (
create_system_content,
extract_number_from_answer_tag,
extract_thinking_content,
get_default_thinking_prompt,
save_eval_results,
validate_thinking_format,
)
from pydantic import Field
from tqdm.asyncio import tqdm_asyncio
@ -38,15 +46,6 @@ from atroposlib.envs.base import (
BaseEnvConfig,
EvalHandlingEnum,
)
from eval_helpers import (
extract_number_from_answer_tag,
validate_thinking_format,
extract_thinking_content,
get_default_thinking_prompt,
create_system_content,
save_eval_results,
)
# Available MuSR subsets
MUSR_SUBSETS = ["murder_mysteries", "object_placements", "team_allocation"]
@ -123,10 +122,10 @@ class MuSREvalConfig(BaseEnvConfig):
class MuSREvalEnv(BaseEnv):
"""
MuSR Evaluation Environment for Atropos.
Evaluates models on multi-step reasoning with long narratives.
"""
name = "musr_eval"
env_config_cls = MuSREvalConfig
@ -140,83 +139,74 @@ class MuSREvalEnv(BaseEnv):
super().__init__(config, server_configs, slurm, testing)
self.config: MuSREvalConfig = config
self.eval_metrics = []
# Regex patterns for thinking mode
self._think_pattern = re.compile(r"<think>")
self._think_close_pattern = re.compile(r"</think>")
self._think_content_pattern = re.compile(r"</think>\s*(.*)", re.DOTALL)
self._thinking_extract_pattern = re.compile(r"<think>(.*?)</think>", re.DOTALL)
# Pre-compile regex for <answer></answer> tag extraction (primary method)
self._answer_tag_pattern = re.compile(r"<answer>(.*?)</answer>", re.DOTALL | re.IGNORECASE)
self._answer_tag_pattern = re.compile(
r"<answer>(.*?)</answer>", re.DOTALL | re.IGNORECASE
)
# Build fallback answer extraction patterns for numbered choices
self._build_extraction_patterns()
def _build_extraction_patterns(self):
"""
Build regex patterns for extracting answer numbers from model responses.
Patterns are ordered by priority (lower number = higher priority).
Takes the LAST match for answer patterns since models often repeat the final answer.
"""
# Number pattern for choices (1-10 to be safe)
num_pattern = r"(\d+)"
# Priority 0: "final answer is: X" with "I hope" (very specific, highest confidence)
self._pattern_final_answer_hope = re.compile(
rf"(?i:final\s+answer\s+is)\s*:?\s*{num_pattern}\.?\s*I\s*hope",
re.IGNORECASE
re.IGNORECASE,
)
# Priority 50: "final answer ... is X" (allows text between)
self._pattern_final_answer_is = re.compile(
rf"(?i:final\s+answer).{{0,100}}?\s+is\s*:?\s*{num_pattern}",
re.IGNORECASE | re.DOTALL
re.IGNORECASE | re.DOTALL,
)
# Priority 75: "the answer is X"
self._pattern_the_answer_is = re.compile(
rf"(?i:the\s+answer\s+is)\s*:?\s*{num_pattern}",
re.IGNORECASE
rf"(?i:the\s+answer\s+is)\s*:?\s*{num_pattern}", re.IGNORECASE
)
# Priority 100: "answer: X" or "Answer: X" (with colon)
self._pattern_answer_colon = re.compile(
rf"(?i:answer)\s*:\s*.{{0,50}}?{num_pattern}",
re.IGNORECASE | re.DOTALL
rf"(?i:answer)\s*:\s*.{{0,50}}?{num_pattern}", re.IGNORECASE | re.DOTALL
)
# Priority 125: "option X" or "choice X"
self._pattern_option = re.compile(
rf"(?i:option|choice)\s+{num_pattern}",
re.IGNORECASE
rf"(?i:option|choice)\s+{num_pattern}", re.IGNORECASE
)
# Priority 150: "answer X" or "Answer X" (without colon)
self._pattern_answer_space = re.compile(
rf"(?i:answer)\s+{num_pattern}",
re.IGNORECASE
rf"(?i:answer)\s+{num_pattern}", re.IGNORECASE
)
# Priority 200: Response starts with number (with optional punctuation)
self._pattern_start = re.compile(
rf"^\s*{num_pattern}[\s\.\)\:]",
re.IGNORECASE
)
self._pattern_start = re.compile(rf"^\s*{num_pattern}[\s\.\)\:]", re.IGNORECASE)
# Priority 210: Number at start of any line
self._pattern_line_start = re.compile(
rf"\n\s*{num_pattern}[\s\.\)\:]",
re.IGNORECASE
rf"\n\s*{num_pattern}[\s\.\)\:]", re.IGNORECASE
)
# Priority 300: Number at end of response
self._pattern_end = re.compile(
rf"{num_pattern}\s*$",
re.IGNORECASE
)
self._pattern_end = re.compile(rf"{num_pattern}\s*$", re.IGNORECASE)
# Store patterns in priority order
self._extraction_patterns = [
(0, self._pattern_final_answer_hope, "final_answer_hope"),
@ -239,7 +229,7 @@ class MuSREvalEnv(BaseEnv):
return create_system_content(
self.config.thinking_mode,
self.config.custom_thinking_prompt,
self.config.custom_system_prompt
self.config.custom_system_prompt,
)
@classmethod
@ -265,7 +255,7 @@ class MuSREvalEnv(BaseEnv):
eval_max_tokens=0,
thinking_mode=True,
)
server_configs = [
APIServerConfig(
model_name="Hermes-3-Llama-3.1-8B",
@ -275,18 +265,18 @@ class MuSREvalEnv(BaseEnv):
num_requests_for_eval=1024,
),
]
return env_config, server_configs
def _format_musr_prompt(self, item: Dict) -> str:
"""Format a MuSR item into a prompt with <answer> tag instruction.
Based on lighteval's musr_prompt but with explicit <answer> tag instruction.
Uses numbered choices (1, 2, 3...) as in the original format.
"""
narrative = item.get("narrative", "")
question = item.get("question", "")
# Parse choices - they're stored as a string representation of a list
choices_raw = item.get("choices", "[]")
if isinstance(choices_raw, str):
@ -296,17 +286,17 @@ class MuSREvalEnv(BaseEnv):
choices = []
else:
choices = choices_raw
num_choices = len(choices)
valid_numbers = ", ".join(str(i+1) for i in range(num_choices))
valid_numbers = ", ".join(str(i + 1) for i in range(num_choices))
query = "Read the narrative and answer the question. Think step by step before answering.\n\n"
query += f"Provide your final answer within <answer></answer> tags, containing only the number ({valid_numbers}).\n\n"
query += "Example format:\n<answer>1</answer>\n\n"
query += f"{narrative}\n\n{question}\n\n"
for i, choice in enumerate(choices):
query += f"{i + 1} - {choice}\n"
return query, choices
async def setup(self) -> None:
@ -318,10 +308,12 @@ class MuSREvalEnv(BaseEnv):
print(f" Thinking mode: {self.config.thinking_mode}")
if self.config.thinking_mode:
print(f" Thinking prompt: {self._get_thinking_prompt()[:100]}...")
if self.config.subset not in MUSR_SUBSETS:
print(f"Warning: Unknown subset '{self.config.subset}'. Available: {MUSR_SUBSETS}")
print(
f"Warning: Unknown subset '{self.config.subset}'. Available: {MUSR_SUBSETS}"
)
try:
# MuSR has splits named after the subsets
dataset = load_dataset(
@ -330,11 +322,11 @@ class MuSREvalEnv(BaseEnv):
)
self.eval_data = list(dataset)
print(f" Loaded {len(self.eval_data)} evaluation items")
except Exception as e:
print(f"Error loading MuSR dataset: {e}")
raise
self.all_eval_items = self.eval_data
self.iter = 0
@ -342,13 +334,13 @@ class MuSREvalEnv(BaseEnv):
"""Validate thinking format and extract content after </think> tags."""
if not self.config.thinking_mode:
return True, response
think_open_count = len(self._think_pattern.findall(response))
think_close_count = len(self._think_close_pattern.findall(response))
if think_open_count != 1 or think_close_count != 1:
return False, response
match = self._think_content_pattern.search(response)
if match:
return True, match.group(1).strip()
@ -362,23 +354,25 @@ class MuSREvalEnv(BaseEnv):
return match.group(1).strip()
return None
def _extract_answer_number(self, response: str, num_choices: int) -> Tuple[Optional[int], str]:
def _extract_answer_number(
self, response: str, num_choices: int
) -> Tuple[Optional[int], str]:
"""
Extract the answer number (1-indexed) from the model's response.
Primary method: Look for <answer></answer> tags (only take the first match).
Fallback: Use priority-ordered regex patterns.
Args:
response: The model's response string (content after </think> in thinking mode)
num_choices: Number of valid choices
Returns:
Tuple of (extracted_number or None, extraction_method used)
"""
if not response:
return None, "empty_response"
# PRIMARY: Try <answer></answer> tags first
# Uses word boundary matching - only accepts if EXACTLY ONE valid number found
number, method = extract_number_from_answer_tag(
@ -386,35 +380,50 @@ class MuSREvalEnv(BaseEnv):
)
if number:
return number, method
# FALLBACK: Try each pattern in priority order
for priority, pattern, method_name in self._extraction_patterns:
matches = pattern.findall(response)
if matches:
# Get the LAST match for answer patterns since final answer is most reliable
match = matches[-1] if method_name in ["final_answer_is", "the_answer_is", "answer_colon", "answer_space", "option"] else matches[0]
match = (
matches[-1]
if method_name
in [
"final_answer_is",
"the_answer_is",
"answer_colon",
"answer_space",
"option",
]
else matches[0]
)
try:
num = int(match)
if 1 <= num <= num_choices:
if self.config.full_debug:
print(f" Extracted '{num}' using fallback method '{method_name}' (priority {priority})")
print(
f" Extracted '{num}' using fallback method '{method_name}' (priority {priority})"
)
return num, f"fallback_{method_name}"
except ValueError:
continue
# Last resort: find any number in valid range (take the last one)
numbers = re.findall(r'\b(\d+)\b', response)
numbers = re.findall(r"\b(\d+)\b", response)
for num_str in reversed(numbers):
try:
num = int(num_str)
if 1 <= num <= num_choices:
if self.config.full_debug:
print(f" Extracted '{num}' using fallback 'last_valid_number'")
print(
f" Extracted '{num}' using fallback 'last_valid_number'"
)
return num, "fallback_last_valid_number"
except ValueError:
continue
return None, "no_match"
async def get_next_item(self):
@ -438,17 +447,17 @@ class MuSREvalEnv(BaseEnv):
try:
prompt, choices = self._format_musr_prompt(eval_item)
gold_index = eval_item.get("answer_index", 0) # 0-indexed
if not prompt or not choices:
return {"result": None, "sample": None}
# Build messages
messages = []
system_content = self._create_system_content()
if system_content:
messages.append({"role": "system", "content": system_content})
messages.append({"role": "user", "content": prompt})
# Get model response
model_response = None
finish_reason = None
@ -462,128 +471,175 @@ class MuSREvalEnv(BaseEnv):
}
if self.config.eval_max_tokens > 0:
completion_kwargs["max_tokens"] = self.config.eval_max_tokens
completion = await self.server.chat_completion(**completion_kwargs)
if completion.choices and completion.choices[0].message.content:
model_response = completion.choices[0].message.content
finish_reason = getattr(completion.choices[0], 'finish_reason', None)
if len(model_response.strip()) >= self.config.min_response_length:
finish_reason = getattr(
completion.choices[0], "finish_reason", None
)
if (
len(model_response.strip())
>= self.config.min_response_length
):
break
except Exception as e:
print(f" API Error (attempt {attempt + 1}/{self.config.max_retries}): {type(e).__name__}: {e}")
if hasattr(e, 'response'):
print(
f" API Error (attempt {attempt + 1}/{self.config.max_retries}): {type(e).__name__}: {e}"
)
if hasattr(e, "response"):
try:
print(f" Response: {e.response.text[:500] if hasattr(e.response, 'text') else e.response}")
print(
f" Response: {e.response.text[:500] if hasattr(e.response, 'text') else e.response}"
)
except:
pass
if attempt < self.config.max_retries - 1:
await asyncio.sleep(self.config.retry_delay)
else:
return {"result": None, "sample": None}
if not model_response:
return {"result": None, "sample": None}
# Handle thinking mode
thinking_format_valid, response_for_eval = self._validate_thinking_format(model_response)
thinking_format_valid, response_for_eval = self._validate_thinking_format(
model_response
)
thinking_content = None
if self.config.thinking_mode:
thinking_content = self._extract_thinking_content(model_response)
# Extract answer (1-indexed)
extracted_answer, extraction_method = self._extract_answer_number(response_for_eval, len(choices))
extracted_answer, extraction_method = self._extract_answer_number(
response_for_eval, len(choices)
)
# Check correctness (gold_index is 0-indexed, extracted is 1-indexed)
is_correct = extracted_answer is not None and (extracted_answer - 1) == gold_index
is_correct = (
extracted_answer is not None and (extracted_answer - 1) == gold_index
)
sample = {
"narrative": eval_item.get("narrative", "")[:300] + "..." if len(eval_item.get("narrative", "")) > 300 else eval_item.get("narrative", ""),
"narrative": (
eval_item.get("narrative", "")[:300] + "..."
if len(eval_item.get("narrative", "")) > 300
else eval_item.get("narrative", "")
),
"question": eval_item.get("question", ""),
"choices": choices,
"gold_index": gold_index,
"gold_answer": choices[gold_index] if 0 <= gold_index < len(choices) else "N/A",
"model_response": model_response[:500] if len(model_response) > 500 else model_response,
"gold_answer": (
choices[gold_index] if 0 <= gold_index < len(choices) else "N/A"
),
"model_response": (
model_response[:500]
if len(model_response) > 500
else model_response
),
"extracted_answer": extracted_answer,
"extraction_method": extraction_method,
"extracted_choice": choices[extracted_answer - 1] if extracted_answer and 1 <= extracted_answer <= len(choices) else "N/A",
"extracted_choice": (
choices[extracted_answer - 1]
if extracted_answer and 1 <= extracted_answer <= len(choices)
else "N/A"
),
"is_correct": is_correct,
"finish_reason": finish_reason,
"thinking_format_valid": thinking_format_valid,
}
if self.config.thinking_mode:
sample["thinking_content"] = thinking_content[:300] + "..." if thinking_content and len(thinking_content) > 300 else thinking_content
sample["thinking_content"] = (
thinking_content[:300] + "..."
if thinking_content and len(thinking_content) > 300
else thinking_content
)
if self.config.full_debug:
status = "" if is_correct else ""
print(f" [{status}] Extracted: {extracted_answer}, Gold: {gold_index + 1}")
return {
"result": {"correct": is_correct},
"sample": sample
}
print(
f" [{status}] Extracted: {extracted_answer}, Gold: {gold_index + 1}"
)
return {"result": {"correct": is_correct}, "sample": sample}
except Exception as e:
if self.config.full_debug:
print(f"Error in rollout_and_score_eval: {e}")
import traceback
traceback.print_exc()
return {"result": None, "sample": None}
async def evaluate(self, *args, **kwargs) -> None:
"""Run MuSR evaluation."""
start_time = time.time()
print(f"\n{'='*60}")
print(f"Starting MuSR Evaluation ({self.config.subset})")
print(f"{'='*60}")
print(f" Total questions: {len(self.all_eval_items)}")
print(f" Thinking mode: {self.config.thinking_mode}")
print(f"{'='*60}\n")
try:
eval_tasks = [
self.rollout_and_score_eval(item) for item in self.all_eval_items
]
results = await tqdm_asyncio.gather(*eval_tasks, desc=f"Evaluating MuSR ({self.config.subset})")
results = await tqdm_asyncio.gather(
*eval_tasks, desc=f"Evaluating MuSR ({self.config.subset})"
)
valid_results = [
r for r in results
r
for r in results
if r and r.get("sample") is not None and r.get("result") is not None
]
if not valid_results:
print("Warning: No valid evaluation results obtained")
return
except Exception as e:
print(f"Error during evaluation: {e}")
import traceback
traceback.print_exc()
return
end_time = time.time()
# Compute metrics
samples = [r["sample"] for r in valid_results]
total_count = len(valid_results)
correct_count = sum(1 for s in samples if s.get("is_correct", False))
accuracy = correct_count / total_count if total_count > 0 else 0.0
# Answer extraction rate
extracted_count = sum(1 for s in samples if s.get("extracted_answer") is not None)
extracted_count = sum(
1 for s in samples if s.get("extracted_answer") is not None
)
extraction_rate = extracted_count / total_count if total_count > 0 else 0.0
# Thinking metrics
thinking_format_compliant = sum(1 for s in samples if s.get("thinking_format_valid", True))
thinking_format_compliance_rate = thinking_format_compliant / total_count if total_count > 0 else 0.0
thinking_utilization = sum(1 for s in samples if s.get("thinking_content")) if self.config.thinking_mode else 0
thinking_format_compliant = sum(
1 for s in samples if s.get("thinking_format_valid", True)
)
thinking_format_compliance_rate = (
thinking_format_compliant / total_count if total_count > 0 else 0.0
)
thinking_utilization = (
sum(1 for s in samples if s.get("thinking_content"))
if self.config.thinking_mode
else 0
)
eval_metrics = {
"eval/accuracy": accuracy,
"eval/correct_count": correct_count,
@ -592,13 +648,17 @@ class MuSREvalEnv(BaseEnv):
"eval/evaluation_time_seconds": end_time - start_time,
"eval/thinking_mode_enabled": 1.0 if self.config.thinking_mode else 0.0,
}
if self.config.thinking_mode:
eval_metrics["eval/thinking_format_compliance_rate"] = thinking_format_compliance_rate
eval_metrics["eval/thinking_utilization_rate"] = thinking_utilization / total_count if total_count > 0 else 0.0
eval_metrics["eval/thinking_format_compliance_rate"] = (
thinking_format_compliance_rate
)
eval_metrics["eval/thinking_utilization_rate"] = (
thinking_utilization / total_count if total_count > 0 else 0.0
)
self.eval_metrics = [(k, v) for k, v in eval_metrics.items()]
# Print summary
print(f"\n{'='*60}")
print(f"MuSR Evaluation Results ({self.config.subset})")
@ -609,7 +669,7 @@ class MuSREvalEnv(BaseEnv):
if self.config.thinking_mode:
print(f"Thinking Format Compliance: {thinking_format_compliance_rate:.4f}")
print(f"{'='*60}\n")
try:
await self.evaluate_log(
metrics=eval_metrics,
@ -630,18 +690,19 @@ class MuSREvalEnv(BaseEnv):
"""Log metrics to wandb."""
if wandb_metrics is None:
wandb_metrics = {}
for metric_name, metric_value in self.eval_metrics:
wandb_metrics[metric_name] = metric_value
self.eval_metrics = []
wandb_metrics["config/thinking_mode"] = 1.0 if self.config.thinking_mode else 0.0
wandb_metrics["config/thinking_mode"] = (
1.0 if self.config.thinking_mode else 0.0
)
wandb_metrics["config/eval_max_tokens"] = self.config.eval_max_tokens
wandb_metrics["config/subset"] = self.config.subset
await super().wandb_log(wandb_metrics)
if __name__ == "__main__":
MuSREvalEnv.cli()