[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
This commit is contained in:
pre-commit-ci[bot] 2025-12-24 10:48:20 +00:00
parent ef9c0c3699
commit afab28dfa9
37 changed files with 4868 additions and 4052 deletions

View file

@ -27,6 +27,19 @@ from typing import Dict, List, Optional, Tuple
import wandb
from datasets import load_dataset
from eval_helpers import (
THINK_CONTENT_AFTER_PATTERN,
compare_math_strings,
create_system_content,
extract_boxed_answers,
extract_thinking_content,
format_math_answer_instruction,
get_default_thinking_prompt,
get_math_executor,
save_eval_results,
score_math_answer_async,
validate_thinking_format,
)
from pydantic import Field
from tqdm.asyncio import tqdm_asyncio
@ -36,20 +49,6 @@ from atroposlib.envs.base import (
BaseEnvConfig,
EvalHandlingEnum,
)
from eval_helpers import (
validate_thinking_format,
extract_thinking_content,
get_default_thinking_prompt,
create_system_content,
save_eval_results,
score_math_answer_async,
get_math_executor,
extract_boxed_answers,
compare_math_strings,
format_math_answer_instruction,
THINK_CONTENT_AFTER_PATTERN,
)
# Prompt template following lighteval's structure
# Added boxed instruction for consistency with our math verification
@ -60,79 +59,59 @@ GSM8K_PROMPT_TEMPLATE = """Solve the following math problem step by step. {answe
class GSM8KEvalConfig(BaseEnvConfig):
"""Configuration for GSM8K evaluation environment."""
# Dataset configuration
dataset_name: str = Field(
default="openai/gsm8k",
description="HuggingFace dataset name"
default="openai/gsm8k", description="HuggingFace dataset name"
)
subset: str = Field(
default="main",
description="Dataset subset"
)
eval_split: str = Field(
default="test",
description="Split to evaluate on"
)
shuffle_seed: int = Field(
default=42,
description="Random seed for shuffling"
)
subset: str = Field(default="main", description="Dataset subset")
eval_split: str = Field(default="test", description="Split to evaluate on")
shuffle_seed: int = Field(default=42, description="Random seed for shuffling")
# Generation parameters
eval_temperature: float = Field(
default=0.6,
description="Temperature for evaluation generation"
default=0.6, description="Temperature for evaluation generation"
)
eval_max_tokens: int = Field(
default=0,
description="Max tokens for evaluation (0 = use model default)"
default=0, description="Max tokens for evaluation (0 = use model default)"
)
# System prompt configuration
custom_system_prompt: Optional[str] = Field(
default=None,
description="Optional custom system prompt"
default=None, description="Optional custom system prompt"
)
# Thinking mode configuration
thinking_mode: bool = Field(
default=True,
description="Whether to use thinking mode with <think></think> tags"
description="Whether to use thinking mode with <think></think> tags",
)
custom_thinking_prompt: Optional[str] = Field(
default=None,
description="Optional custom thinking prompt"
default=None, description="Optional custom thinking prompt"
)
# Math verification configuration
include_hope_suffix: bool = Field(
default=True,
description="Whether to include 'I hope it is correct' in answer instruction"
description="Whether to include 'I hope it is correct' in answer instruction",
)
max_math_workers: int = Field(
default=64,
description="Maximum workers for math verification ProcessPoolExecutor"
description="Maximum workers for math verification ProcessPoolExecutor",
)
# Retry and debug configuration
max_retries: int = Field(
default=3,
description="Maximum retries for failed API calls"
default=3, description="Maximum retries for failed API calls"
)
retry_delay: float = Field(
default=1.0,
description="Delay between retries in seconds"
default=1.0, description="Delay between retries in seconds"
)
min_response_length: int = Field(
default=1,
description="Minimum response length to consider valid"
default=1, description="Minimum response length to consider valid"
)
full_debug: bool = Field(
default=False,
description="Enable full debug output"
)
full_debug: bool = Field(default=False, description="Enable full debug output")
# Override defaults
group_size: int = 1
max_num_workers: int = 1024
@ -148,7 +127,7 @@ class GSM8KEvalConfig(BaseEnvConfig):
class GSM8KEvalEnv(BaseEnv):
"""
GSM8K Evaluation Environment.
Evaluates grade school math word problem solving using the GSM8K dataset.
Uses math_verify for robust answer verification with string fallback.
"""
@ -175,51 +154,53 @@ class GSM8KEvalEnv(BaseEnv):
async def setup(self) -> None:
"""Initialize the environment and load the dataset."""
await super().setup()
# Initialize math executor
self._math_executor = get_math_executor(self.config.max_math_workers)
if not self._dataset_loaded:
await self._load_dataset()
print(f"\nGSM8K Evaluation Setup (Generative Mode):")
print(f" Dataset: {self.config.dataset_name}")
print(f" Subset: {self.config.subset}")
print(f" Evaluation split: {self.config.eval_split}")
print(f" Thinking mode: {self.config.thinking_mode}")
if self.config.thinking_mode:
thinking_prompt = get_default_thinking_prompt(self.config.custom_thinking_prompt)
thinking_prompt = get_default_thinking_prompt(
self.config.custom_thinking_prompt
)
print(f" Thinking prompt: {thinking_prompt[:80]}...")
print(f" Loaded {len(self.eval_items)} evaluation items")
async def _load_dataset(self) -> None:
"""Load and process the GSM8K dataset."""
print(f"Loading GSM8K dataset: {self.config.dataset_name}/{self.config.subset}...")
print(
f"Loading GSM8K dataset: {self.config.dataset_name}/{self.config.subset}..."
)
try:
dataset = load_dataset(
self.config.dataset_name,
self.config.subset,
trust_remote_code=True
self.config.dataset_name, self.config.subset, trust_remote_code=True
)
except Exception as e:
print(f"Error loading dataset: {e}")
raise
if self.config.eval_split not in dataset:
available_splits = list(dataset.keys())
raise ValueError(
f"Split '{self.config.eval_split}' not found. Available: {available_splits}"
)
split_data = dataset[self.config.eval_split]
# Process items
self.eval_items = []
for idx, item in enumerate(split_data):
question = item.get("question", "")
answer_text = item.get("answer", "")
# Extract the final answer from GSM8K format
# GSM8K answers are formatted as: "reasoning #### final_number"
if "####" in answer_text:
@ -227,18 +208,20 @@ class GSM8KEvalEnv(BaseEnv):
final_answer = parts[-1].strip().replace(",", "")
else:
final_answer = answer_text.strip()
self.eval_items.append({
"id": str(idx),
"question": question,
"answer": final_answer,
"full_answer": answer_text, # Keep full solution for reference
})
self.eval_items.append(
{
"id": str(idx),
"question": question,
"answer": final_answer,
"full_answer": answer_text, # Keep full solution for reference
}
)
# Shuffle with seed
random.seed(self.config.shuffle_seed)
random.shuffle(self.eval_items)
self._dataset_loaded = True
print(f"Loaded {len(self.eval_items)} evaluation items")
@ -248,17 +231,19 @@ class GSM8KEvalEnv(BaseEnv):
include_hope=self.config.include_hope_suffix
)
return GSM8K_PROMPT_TEMPLATE.format(
answer_instruction=answer_instruction,
problem=item["question"]
answer_instruction=answer_instruction, problem=item["question"]
)
def _create_system_content(self) -> str:
"""Create system message content based on thinking mode."""
return create_system_content(
self.config.thinking_mode,
self.config.custom_thinking_prompt,
self.config.custom_system_prompt
) or ""
return (
create_system_content(
self.config.thinking_mode,
self.config.custom_thinking_prompt,
self.config.custom_system_prompt,
)
or ""
)
async def rollout_and_score_eval(
self,
@ -270,12 +255,12 @@ class GSM8KEvalEnv(BaseEnv):
"""
prompt = self._format_prompt(item)
system_content = self._create_system_content()
messages = []
if system_content:
messages.append({"role": "system", "content": system_content})
messages.append({"role": "user", "content": prompt})
# Build API call parameters
kwargs = {
"model": server.model_name,
@ -284,35 +269,38 @@ class GSM8KEvalEnv(BaseEnv):
}
if self.config.eval_max_tokens > 0:
kwargs["max_tokens"] = self.config.eval_max_tokens
response_text = ""
for attempt in range(self.config.max_retries):
try:
response = await self.server.chat_completion(**kwargs)
response_text = response.choices[0].message.content or ""
if len(response_text) >= self.config.min_response_length:
break
except Exception as e:
if self.config.full_debug:
print(f" API error (attempt {attempt + 1}): {e}")
if attempt < self.config.max_retries - 1:
await asyncio.sleep(self.config.retry_delay)
continue
if not response_text:
return None
# Validate thinking format and extract content after </think>
is_valid_format, content_for_extraction = validate_thinking_format(
response_text,
self.config.thinking_mode
response_text, self.config.thinking_mode
)
# Extract thinking content if present
thinking_content = extract_thinking_content(response_text) if self.config.thinking_mode else None
thinking_content = (
extract_thinking_content(response_text)
if self.config.thinking_mode
else None
)
# Score using math_verify with string fallback
gold_answer = item["answer"]
is_correct, method, has_multiple_boxed = await score_math_answer_async(
@ -321,19 +309,19 @@ class GSM8KEvalEnv(BaseEnv):
after_think=self.config.thinking_mode,
wrap_gold_boxed=True,
executor=self._math_executor,
debug=self.config.full_debug
debug=self.config.full_debug,
)
# Extract the boxed answer for logging
if self.config.thinking_mode:
match = THINK_CONTENT_AFTER_PATTERN.search(response_text)
score_content = match.group(1) if match else response_text
else:
score_content = response_text
boxed_answers = extract_boxed_answers(score_content)
extracted_answer = boxed_answers[0] if boxed_answers else None
if self.config.full_debug:
print(f"\n--- Item: {item['id']} ---")
print(f"Question: {item['question'][:100]}...")
@ -342,7 +330,7 @@ class GSM8KEvalEnv(BaseEnv):
print(f"Correct: {is_correct} (method: {method})")
if has_multiple_boxed:
print(f"WARNING: Multiple \\boxed{{}} found - marked incorrect")
return {
"item_id": item["id"],
"question": item["question"],
@ -365,47 +353,48 @@ class GSM8KEvalEnv(BaseEnv):
print(f" Total questions: {len(self.eval_items)}")
print(f" Thinking mode: {self.config.thinking_mode}")
print(f"{'='*60}\n")
# Create evaluation tasks
async def eval_task(item):
return await self.rollout_and_score_eval(item, self.server_configs[0])
tasks = [eval_task(item) for item in self.eval_items]
# Run with progress bar
results = await tqdm_asyncio.gather(
*tasks,
desc="Evaluating GSM8K"
)
results = await tqdm_asyncio.gather(*tasks, desc="Evaluating GSM8K")
# Filter out failed results
valid_results = [r for r in results if r is not None]
if not valid_results:
print("Warning: No valid evaluation results obtained")
return {"error": "No valid results", "accuracy": 0.0}
# Calculate metrics
total = len(valid_results)
correct = sum(1 for r in valid_results if r["is_correct"])
accuracy = correct / total if total > 0 else 0.0
# Count verification methods
method_counts = {}
for r in valid_results:
method = r.get("verification_method", "unknown")
method_counts[method] = method_counts.get(method, 0) + 1
# Count multiple boxed failures
multiple_boxed = sum(1 for r in valid_results if r.get("has_multiple_boxed", False))
multiple_boxed = sum(
1 for r in valid_results if r.get("has_multiple_boxed", False)
)
# Format compliance and thinking utilization
format_valid = sum(1 for r in valid_results if r.get("format_valid", True))
has_thinking = sum(1 for r in valid_results if r.get("has_thinking", False))
# Count how many had a boxed answer at all
has_boxed = sum(1 for r in valid_results if r.get("extracted_answer") is not None)
has_boxed = sum(
1 for r in valid_results if r.get("extracted_answer") is not None
)
metrics = {
"accuracy": accuracy,
"total_evaluated": total,
@ -416,7 +405,7 @@ class GSM8KEvalEnv(BaseEnv):
"thinking_utilization_rate": has_thinking / total if total > 0 else 0.0,
"verification_methods": method_counts,
}
print(f"\n{'='*60}")
print("GSM8K Evaluation Results")
print(f"{'='*60}")
@ -431,11 +420,11 @@ class GSM8KEvalEnv(BaseEnv):
for method, count in sorted(method_counts.items(), key=lambda x: -x[1]):
print(f" {method}: {count} ({count/total:.1%})")
print(f"{'='*60}\n")
# Save results
if self.config.data_dir_to_save_evals:
self._save_results(metrics, valid_results)
return metrics
def _save_results(self, metrics: Dict, results: List[Dict]) -> None:
@ -446,16 +435,18 @@ class GSM8KEvalEnv(BaseEnv):
"""Log metrics to Weights & Biases."""
if not self.config.use_wandb:
return
log_dict = {
"gsm8k/accuracy": metrics.get("accuracy", 0),
"gsm8k/total_evaluated": metrics.get("total_evaluated", 0),
"gsm8k/has_boxed_rate": metrics.get("has_boxed_rate", 0),
"gsm8k/multiple_boxed_rate": metrics.get("multiple_boxed_rate", 0),
"gsm8k/format_compliance_rate": metrics.get("format_compliance_rate", 0),
"gsm8k/thinking_utilization_rate": metrics.get("thinking_utilization_rate", 0),
"gsm8k/thinking_utilization_rate": metrics.get(
"thinking_utilization_rate", 0
),
}
wandb.log(log_dict, step=step)
# Required abstract method implementations
@ -474,4 +465,3 @@ class GSM8KEvalEnv(BaseEnv):
if __name__ == "__main__":
GSM8KEvalEnv.cli()