[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
This commit is contained in:
pre-commit-ci[bot] 2025-12-24 10:48:20 +00:00
parent ef9c0c3699
commit afab28dfa9
37 changed files with 4868 additions and 4052 deletions

View file

@ -28,6 +28,16 @@ from typing import Any, Dict, List, Optional, Tuple
import wandb
from datasets import load_dataset
from eval_helpers import (
ANSWER_TAG_PATTERN,
build_mcqa_fallback_patterns,
create_system_content,
extract_letter_from_answer_tag,
extract_thinking_content,
get_default_thinking_prompt,
save_eval_results,
validate_thinking_format,
)
from pydantic import Field
from tqdm.asyncio import tqdm_asyncio
@ -37,16 +47,6 @@ from atroposlib.envs.base import (
BaseEnvConfig,
EvalHandlingEnum,
)
from eval_helpers import (
extract_letter_from_answer_tag,
validate_thinking_format,
extract_thinking_content,
get_default_thinking_prompt,
create_system_content,
save_eval_results,
build_mcqa_fallback_patterns,
ANSWER_TAG_PATTERN,
)
class BoolQEvalConfig(BaseEnvConfig):
@ -120,10 +120,10 @@ class BoolQEvalConfig(BaseEnvConfig):
class BoolQEvalEnv(BaseEnv):
"""
BoolQ Evaluation Environment for Atropos.
Evaluates models on reading comprehension with yes/no questions.
"""
name = "boolq_eval"
env_config_cls = BoolQEvalConfig
@ -137,12 +137,12 @@ class BoolQEvalEnv(BaseEnv):
super().__init__(config, server_configs, slurm, testing)
self.config: BoolQEvalConfig = config
self.eval_metrics = []
# For BoolQ we use Yes/No directly, not letter choices
self._valid_answers = {'yes', 'no'}
self._valid_answers = {"yes", "no"}
# But also support A/B format
self._fallback_patterns = build_mcqa_fallback_patterns(2)
self._valid_letters = {'A', 'B'}
self._valid_letters = {"A", "B"}
@classmethod
def config_init(cls) -> Tuple[BoolQEvalConfig, List[APIServerConfig]]:
@ -175,31 +175,33 @@ class BoolQEvalEnv(BaseEnv):
print(f" Evaluation split: {self.config.eval_split}")
print(f" Thinking mode: {self.config.thinking_mode}")
if self.config.thinking_mode:
print(f" Thinking prompt: {get_default_thinking_prompt(self.config.custom_thinking_prompt)[:80]}...")
print(
f" Thinking prompt: {get_default_thinking_prompt(self.config.custom_thinking_prompt)[:80]}..."
)
# Load dataset
self.dataset = load_dataset(
self.config.dataset_name,
split=self.config.eval_split,
trust_remote_code=True,
)
self.eval_items = list(self.dataset)
print(f" Loaded {len(self.eval_items)} evaluation items")
def _format_prompt(self, item: Dict) -> str:
"""
Format a BoolQ item into a prompt.
BoolQ has a passage and a question that should be answered Yes or No.
"""
passage = item['passage']
question = item['question']
passage = item["passage"]
question = item["question"]
# Clean up double question marks
if question.endswith('??'):
if question.endswith("??"):
question = question[:-1]
# Build the question
query = f"Passage: {passage}\n\n"
query += f"Question: {question}\n\n"
@ -207,7 +209,7 @@ class BoolQEvalEnv(BaseEnv):
query += "A. Yes\n"
query += "B. No\n"
query += "\nProvide your answer in <answer></answer> tags with only the letter (A or B), or 'Yes'/'No'."
return query
def _create_system_content(self) -> Optional[str]:
@ -215,13 +217,13 @@ class BoolQEvalEnv(BaseEnv):
return create_system_content(
self.config.thinking_mode,
self.config.custom_thinking_prompt,
self.config.custom_system_prompt
self.config.custom_system_prompt,
)
def _extract_answer(self, response: str) -> Tuple[Optional[str], str]:
"""
Extract the answer from the model's response.
Accepts:
- A/B letters (converted to Yes/No)
- Yes/No directly
@ -235,71 +237,83 @@ class BoolQEvalEnv(BaseEnv):
response_to_parse = response
else:
response_to_parse = response
# Try <answer></answer> tags first
answer_match = ANSWER_TAG_PATTERN.search(response_to_parse)
if answer_match:
answer_content = answer_match.group(1).strip().lower()
# Direct Yes/No
if 'yes' in answer_content and 'no' not in answer_content:
return 'Yes', 'answer_tag_yes'
if 'no' in answer_content and 'yes' not in answer_content:
return 'No', 'answer_tag_no'
if "yes" in answer_content and "no" not in answer_content:
return "Yes", "answer_tag_yes"
if "no" in answer_content and "yes" not in answer_content:
return "No", "answer_tag_no"
# A/B letters
if answer_content in ['a', 'a.', '(a)']:
return 'Yes', 'answer_tag_letter_a'
if answer_content in ['b', 'b.', '(b)']:
return 'No', 'answer_tag_letter_b'
if answer_content in ["a", "a.", "(a)"]:
return "Yes", "answer_tag_letter_a"
if answer_content in ["b", "b.", "(b)"]:
return "No", "answer_tag_letter_b"
# Check for letter anywhere in short content
if len(answer_content) <= 10:
if 'a' in answer_content and 'b' not in answer_content:
return 'Yes', 'answer_tag_letter_a'
if 'b' in answer_content and 'a' not in answer_content:
return 'No', 'answer_tag_letter_b'
if "a" in answer_content and "b" not in answer_content:
return "Yes", "answer_tag_letter_a"
if "b" in answer_content and "a" not in answer_content:
return "No", "answer_tag_letter_b"
# Fallback: Try letter patterns
letter, method = extract_letter_from_answer_tag(
response_to_parse,
self._valid_letters,
debug=self.config.full_debug,
choices=['Yes', 'No']
choices=["Yes", "No"],
)
if letter:
return 'Yes' if letter == 'A' else 'No', method
return "Yes" if letter == "A" else "No", method
# Fallback: Look for Yes/No in response
response_lower = response_to_parse.lower()
# Check for explicit patterns
yes_patterns = [r'\byes\b', r'\banswer is yes\b', r'\bthe answer is yes\b']
no_patterns = [r'\bno\b', r'\banswer is no\b', r'\bthe answer is no\b']
yes_patterns = [r"\byes\b", r"\banswer is yes\b", r"\bthe answer is yes\b"]
no_patterns = [r"\bno\b", r"\banswer is no\b", r"\bthe answer is no\b"]
yes_matches = sum(1 for p in yes_patterns if re.search(p, response_lower))
no_matches = sum(1 for p in no_patterns if re.search(p, response_lower))
# Only accept if one is clearly dominant
if yes_matches > 0 and no_matches == 0:
return 'Yes', 'fallback_yes_keyword'
return "Yes", "fallback_yes_keyword"
if no_matches > 0 and yes_matches == 0:
return 'No', 'fallback_no_keyword'
return "No", "fallback_no_keyword"
# Try MCQA fallback patterns for A/B
for priority, pattern, method_name in self._fallback_patterns:
matches = pattern.findall(response_to_parse)
if matches:
match = matches[-1] if method_name in ["final_answer_is", "the_answer_is", "answer_colon", "answer_space"] else matches[0]
match = (
matches[-1]
if method_name
in [
"final_answer_is",
"the_answer_is",
"answer_colon",
"answer_space",
]
else matches[0]
)
if isinstance(match, tuple):
match = match[0]
letter = match.strip("()").upper()
if letter in self._valid_letters:
return 'Yes' if letter == 'A' else 'No', f"fallback_{method_name}"
return "Yes" if letter == "A" else "No", f"fallback_{method_name}"
return None, "no_match"
async def _generate_with_retry(self, messages: List[Dict], item_id: str) -> Optional[str]:
async def _generate_with_retry(
self, messages: List[Dict], item_id: str
) -> Optional[str]:
"""Generate response with retry logic."""
for attempt in range(self.config.max_retries):
try:
@ -310,56 +324,56 @@ class BoolQEvalEnv(BaseEnv):
}
if self.config.eval_max_tokens > 0:
api_params["max_tokens"] = self.config.eval_max_tokens
response = await self.client.chat.completions.create(**api_params)
if response.choices and response.choices[0].message.content:
content = response.choices[0].message.content.strip()
if len(content) >= self.config.min_response_length:
return content
except Exception as e:
if self.config.full_debug:
print(f" Error on item {item_id} attempt {attempt + 1}: {e}")
if attempt < self.config.max_retries - 1:
await asyncio.sleep(self.config.retry_delay * (attempt + 1))
return None
async def _evaluate_single_item(self, item: Dict, idx: int) -> Dict:
"""Evaluate a single BoolQ item."""
# Format prompt
prompt = self._format_prompt(item)
# Build messages
messages = []
system_content = self._create_system_content()
if system_content:
messages.append({"role": "system", "content": system_content})
messages.append({"role": "user", "content": prompt})
# Generate response
response = await self._generate_with_retry(messages, str(idx))
if response is None:
return {
"index": idx,
"is_correct": False,
"extracted_answer": None,
"gold_answer": item['answer'],
"gold_answer": item["answer"],
"extraction_method": "generation_failed",
"error": "Failed to generate response",
}
# Extract answer
extracted_answer, extraction_method = self._extract_answer(response)
# Gold answer (already Yes/No string)
gold_answer = item['answer']
gold_answer = item["answer"]
# Score
is_correct = extracted_answer == gold_answer if extracted_answer else False
result = {
"index": idx,
"is_correct": is_correct,
@ -367,11 +381,11 @@ class BoolQEvalEnv(BaseEnv):
"gold_answer": gold_answer,
"extraction_method": extraction_method,
}
if self.config.full_debug:
result["response"] = response
result["prompt"] = prompt
return result
async def evaluate(self, *args, **kwargs):
@ -382,38 +396,42 @@ class BoolQEvalEnv(BaseEnv):
print(f" Total questions: {len(self.eval_items)}")
print(f" Thinking mode: {self.config.thinking_mode}")
print("=" * 60)
# Evaluate all items
tasks = [
self._evaluate_single_item(item, idx)
for idx, item in enumerate(self.eval_items)
]
results = await tqdm_asyncio.gather(*tasks, desc="Evaluating BoolQ")
# Calculate metrics
valid_results = [r for r in results if r.get("gold_answer") is not None]
if not valid_results:
print("Warning: No valid evaluation results obtained")
return
correct = sum(1 for r in valid_results if r["is_correct"])
total = len(valid_results)
accuracy = correct / total if total > 0 else 0.0
# Extraction method breakdown
method_counts = {}
for r in valid_results:
method = r.get("extraction_method", "unknown")
method_counts[method] = method_counts.get(method, 0) + 1
# Yes/No breakdown
yes_count = sum(1 for r in valid_results if r["gold_answer"] == "Yes")
no_count = sum(1 for r in valid_results if r["gold_answer"] == "No")
yes_correct = sum(1 for r in valid_results if r["gold_answer"] == "Yes" and r["is_correct"])
no_correct = sum(1 for r in valid_results if r["gold_answer"] == "No" and r["is_correct"])
yes_correct = sum(
1 for r in valid_results if r["gold_answer"] == "Yes" and r["is_correct"]
)
no_correct = sum(
1 for r in valid_results if r["gold_answer"] == "No" and r["is_correct"]
)
# Print summary
print("\n" + "=" * 60)
print("BoolQ Evaluation Results")
@ -422,14 +440,18 @@ class BoolQEvalEnv(BaseEnv):
print(f" Correct: {correct}")
print(f" Accuracy: {accuracy:.2%}")
print("-" * 60)
print(f" Yes questions: {yes_count} (correct: {yes_correct}, acc: {yes_correct/yes_count:.2%})")
print(f" No questions: {no_count} (correct: {no_correct}, acc: {no_correct/no_count:.2%})")
print(
f" Yes questions: {yes_count} (correct: {yes_correct}, acc: {yes_correct/yes_count:.2%})"
)
print(
f" No questions: {no_count} (correct: {no_correct}, acc: {no_correct/no_count:.2%})"
)
print("-" * 60)
print(" Extraction Methods:")
for method, count in sorted(method_counts.items(), key=lambda x: -x[1]):
print(f" {method}: {count} ({count/total:.1%})")
print("=" * 60)
# Save results
metrics = {
"accuracy": accuracy,
@ -439,17 +461,15 @@ class BoolQEvalEnv(BaseEnv):
"no_accuracy": no_correct / no_count if no_count > 0 else 0.0,
"extraction_methods": method_counts,
}
save_eval_results(
self.config.data_dir_to_save_evals,
metrics,
results
)
self.eval_metrics = [{
"accuracy": accuracy,
"total": total,
}]
save_eval_results(self.config.data_dir_to_save_evals, metrics, results)
self.eval_metrics = [
{
"accuracy": accuracy,
"total": total,
}
]
async def wandb_log(self, step: int):
"""Log metrics to wandb."""
@ -470,4 +490,3 @@ class BoolQEvalEnv(BaseEnv):
# Script entry point: runs only when this module is executed directly, not on import.
# It hands control to BoolQEvalEnv.cli(). NOTE(review): `cli` is not defined in the
# visible part of this file — presumably inherited from BaseEnv; confirm there for
# the supported command-line arguments.
if __name__ == "__main__":
BoolQEvalEnv.cli()