removed reward function registry

This commit is contained in:
Shannon Sands 2025-05-12 07:37:38 +10:00
parent 4e7fcd3c9a
commit 137f8381ec
2 changed files with 56 additions and 164 deletions

View file

@ -5,22 +5,21 @@ import random
import re
from typing import Any, Dict, List, Optional, Tuple, Union
from trajectoryhandler.envs.base import (
from atroposlib.envs.base import (
BaseEnv,
BaseEnvConfig,
OpenaiConfig,
ScoredDataGroup,
)
from trajectoryhandler.envs.reward_fns import registry
from trajectoryhandler.envs.reward_fns.combined_reward import CombinedReward
from trajectoryhandler.utils.tokenize_for_trainer import tokenize_for_trainer
from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer
from .curriculum import MathCurriculum
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
system_prompt = """You are an expert mathematician. You need to solve the given math problem step-by-step, showing your reasoning clearly. You should enclose your thoughts and internal monologue inside <think> </think> tags, and then provide your final answer in a LaTeX format using \\boxed{your answer here}.
system_prompt = """You are an expert mathematician that can use extremely long chains of thought to deeply consider the problem and deliberate with yourself via systematic reasoning processes to help come to a correct solution prior to answering.
You should enclose your thoughts and internal monologue inside <think> </think> tags, and then provide your final answer in a LaTeX format using \\boxed{your answer here}.
The problems will be given in a LaTeX format, so be sure to follow the LaTeX syntax when writing your answer (although no $ delimiters are necessary).
@ -55,6 +54,8 @@ class InfiniteMathEnvConfig(BaseEnvConfig):
max_attempts_per_problem: int = 3
correct_reward: float = 1.0
incorrect_reward: float = -1.0
think_block_bonus: float = 0.2 # Bonus for a well-formed think block
boxed_answer_bonus: float = 0.2 # Bonus for a well-formed boxed answer
# Length penalty parameters
apply_length_penalty: bool = True
@ -66,16 +67,6 @@ class InfiniteMathEnvConfig(BaseEnvConfig):
temperature: float = 0.7
top_p: float = 0.9
# Reward functions
reward_functions: List[Union[str, Dict[str, Any]]] = ["accuracy", "format", "boxed"]
accuracy_reward_weight: float = 1.0 # Weight for the accuracy reward
format_reward_weight: float = (
0.2 # Weight for the format reward relative to correctness
)
boxed_reward_weight: float = (
0.3 # Weight for the boxed answer reward relative to correctness
)
class InfiniteMathEnv(BaseEnv):
"""Environment for procedurally generated math problems with curriculum advancement."""
@ -103,70 +94,6 @@ class InfiniteMathEnv(BaseEnv):
# Set the system prompt
self.system_prompt = system_prompt
# Initialize reward function
self.reward_function = self._initialize_reward_function()
def _initialize_reward_function(self):
    """Build the reward callable described by ``self.config.reward_functions``.

    Each entry of ``reward_functions`` is turned into a reward spec:

    * the string names ``"accuracy"``, ``"format"`` and ``"boxed"`` are
      expanded into a dict with ``type``, ``weight`` (read from the matching
      ``*_reward_weight`` config field) and fixed ``params``;
    * any other string, and any non-string entry, is passed through unchanged.

    Returns:
        ``registry.create(spec)`` when exactly one spec results, otherwise a
        ``CombinedReward`` over all specs. Returns ``None`` when the config
        has no ``reward_functions`` attribute or it is empty.
    """
    requested = getattr(self.config, "reward_functions", None)
    if not requested:
        return None

    # Known reward names -> (config field holding the weight, params factory).
    # Factories are used so every spec gets its own fresh params dict, and the
    # weight field is only read when that reward is actually requested.
    known_rewards = {
        "accuracy": (
            "accuracy_reward_weight",
            lambda: {
                "split_on_think_tag": True,  # Only look at what's after </think>
                "tolerance": 1e-6,  # Tolerance for number comparisons
            },
        ),
        "format": (
            "format_reward_weight",
            lambda: {"preferred_tags": ["think"]},
        ),
        "boxed": (
            "boxed_reward_weight",
            lambda: {"require_outside_think": True},
        ),
    }

    specs = []
    for entry in requested:
        if not isinstance(entry, str):
            # Dict case - pass through as is
            logger.info(f"Adding reward config: {entry}")
            specs.append(entry)
            continue

        if entry not in known_rewards:
            # Unrecognised name - pass through as is
            logger.info(f"Adding generic reward function: {entry}")
            specs.append(entry)
            continue

        weight_field, make_params = known_rewards[entry]
        spec = {
            "type": entry,
            "weight": getattr(self.config, weight_field),
            "params": make_params(),
        }
        logger.info(f"Adding {entry} reward with config: {spec}")
        specs.append(spec)

    if len(specs) == 1:
        sole_spec = specs[0]
        logger.info(f"Creating single reward function: {sole_spec}")
        return registry.create(sole_spec)

    logger.info(f"Creating combined reward function with {len(specs)} rewards")
    # NOTE(review): an earlier comment here claimed the weights are normalized
    # to sum to 1.0, but normalization="none" is what is passed - confirm the
    # intended behaviour against CombinedReward.
    return CombinedReward(rewards=specs, normalization="none")
async def setup(self):
"""Initialize the environment and curriculum."""
logger.info("Setting up InfiniteMathEnv")
@ -340,21 +267,22 @@ class InfiniteMathEnv(BaseEnv):
)
# Log reward function metrics
if hasattr(self, "reward_function") and self.wandb:
if hasattr(self.reward_function, "set_wandb_logger"):
self.reward_function.set_wandb_logger(self.wandb)
# REMOVED: Specific reward function config logging as it's not used anymore
# if hasattr(self, "reward_function") and self.wandb:
# if hasattr(self.reward_function, "set_wandb_logger"):
# self.reward_function.set_wandb_logger(self.wandb)
# Log the reward configurations
if isinstance(self.config.reward_functions, list) and self.config.reward_functions:
# Log the reward configuration
wandb_metrics["reward/format_reward_enabled"] = "format" in self.config.reward_functions
wandb_metrics["reward/boxed_reward_enabled"] = "boxed" in self.config.reward_functions
# # Log the reward configurations
# if isinstance(self.config.reward_functions, list) and self.config.reward_functions:
# # Log the reward configuration
# wandb_metrics["reward/format_reward_enabled"] = "format" in self.config.reward_functions
# wandb_metrics["reward/boxed_reward_enabled"] = "boxed" in self.config.reward_functions
if hasattr(self.config, "format_reward_weight"):
wandb_metrics["reward/format_reward_weight"] = self.config.format_reward_weight
# if hasattr(self.config, "format_reward_weight"):
# wandb_metrics["reward/format_reward_weight"] = self.config.format_reward_weight
if hasattr(self.config, "boxed_reward_weight"):
wandb_metrics["reward/boxed_reward_weight"] = self.config.boxed_reward_weight
# if hasattr(self.config, "boxed_reward_weight"):
# wandb_metrics["reward/boxed_reward_weight"] = self.config.boxed_reward_weight
# Add eval metrics
for item in self.eval_metrics:
@ -502,7 +430,7 @@ class InfiniteMathEnv(BaseEnv):
)
# Extract the boxed answer if present
boxed_answer = self.extract_boxed_answer(after_think_part)
boxed_answer = self._extract_boxed_answer(after_think_part)
if not boxed_answer:
# Try to find the answer in the last line
lines = after_think_part.strip().split("\n")
@ -510,15 +438,15 @@ class InfiniteMathEnv(BaseEnv):
boxed_answer = lines[-1].strip()
# Clean up answers for comparison (remove spaces, convert to lowercase)
model_clean = self.clean_for_comparison(
model_clean = self._clean_for_comparison(
boxed_answer if boxed_answer else after_think_part
)
solution_clean = self.clean_for_comparison(solution)
solution_clean = self._clean_for_comparison(solution)
# Check if they match
return model_clean == solution_clean
def extract_boxed_answer(self, text: str) -> Optional[str]:
def _extract_boxed_answer(self, text: str) -> Optional[str]:
"""Extract answer from a LaTeX boxed expression."""
# Try to find boxed content
boxed_match = re.search(r"\\boxed{([^}]*)}", text)
@ -526,7 +454,7 @@ class InfiniteMathEnv(BaseEnv):
return boxed_match.group(1)
return None
def clean_for_comparison(self, text: str) -> str:
def _clean_for_comparison(self, text: str) -> str:
"""Clean text for comparison."""
# Remove LaTeX commands, spaces, commas, and convert to lowercase
cleaned = re.sub(r"\\[a-zA-Z]+", "", text)
@ -604,86 +532,54 @@ class InfiniteMathEnv(BaseEnv):
scored_data["scores"] = []
scored_data["messages"] = []
# Format completions for reward function evaluation
format_completions = []
# Process each item in the rollout data
for messages, solution, generator_id, level in rollout_group_data:
# Extract the model's answer
model_answer = messages[-1]["content"]
# Add to format completions list for reward function
format_completions.append([{"role": "assistant", "content": model_answer}])
# Record performance in curriculum based on the answer and solution
# This will be updated after the reward functions are applied
# Apply all reward functions
reward_scores = []
unweighted_scores = []
if hasattr(self, "reward_function") and self.reward_function:
try:
# Apply the reward function (which may be a combined reward)
reward_scores = self.reward_function(format_completions, solution=solution)
logger.info(f"Reward scores: {reward_scores}")
# Debug individual rewards if it's a combined reward
if hasattr(self.reward_function, "rewards"):
logger.info(f"Combined reward with {len(self.reward_function.rewards)} components")
for i, reward in enumerate(self.reward_function.rewards):
if hasattr(reward, "compute"):
# Get raw unweighted scores
raw_scores = reward.compute(format_completions, solution=solution)
if hasattr(reward, "weight"):
logger.info(f"Reward {i} ({type(reward).__name__}): raw={raw_scores}, weight={reward.weight}")
else:
logger.info(f"Reward {i} ({type(reward).__name__}): raw={raw_scores}")
else:
logger.info(f"Using single reward: {type(self.reward_function).__name__}")
except Exception as e:
logger.error(f"Error applying reward functions: {e}")
logger.exception(e)
reward_scores = [0.0] * len(format_completions)
# Now update curriculum based on accuracy reward results
for i, (messages, solution, generator_id, level) in enumerate(rollout_group_data):
# Extract accuracy from the combined reward if available
is_correct = False
if reward_scores and hasattr(self.reward_function, "rewards"):
for reward in self.reward_function.rewards:
if type(reward).__name__ == "AccuracyReward":
# Get raw scores from accuracy reward
accuracy_scores = reward.compute(format_completions, solution=solution)
is_correct = accuracy_scores[i] > 0
break
model_answer = messages[-1]["content"]
current_score = 0.0
# 1. Accuracy Check
is_correct = self.check_answer(model_answer, solution)
if is_correct:
current_score += self.config.correct_reward
else:
current_score += self.config.incorrect_reward
# Record answer correctness for tracking
# Record answer correctness for tracking and curriculum
self.percent_correct_buffer.append(1 if is_correct else 0)
if level is not None:
self.level_correct_buffer[level].append(1 if is_correct else 0)
# Record performance in curriculum
self.curriculum.record_performance(generator_id, is_correct)
# Combine scores and add to scored data
for i, (messages, _, _, _) in enumerate(rollout_group_data):
# Use the reward score directly (all weights are applied)
combined_score = reward_scores[i] if reward_scores else 0.0
logger.info(f"Final score for item {i}: {combined_score}")
# 2. Thinking Block Check
think_match = re.search(r"<think>(.*?)</think>", model_answer, re.DOTALL)
if think_match:
think_content = think_match.group(1).strip()
if think_content: # Check if there's actual content
current_score += self.config.think_block_bonus
# else: penalty for empty think block, or neutral
# else: penalty for missing think block, or neutral
# 3. Boxed Answer Check
# Extract the part after the thinking block for boxed answer validation
after_think_part = model_answer.split("</think>")[-1].strip() if "</think>" in model_answer else model_answer
boxed_answer_content = self._extract_boxed_answer(after_think_part)
if boxed_answer_content is not None: # Check if \boxed{} is present and has content
current_score += self.config.boxed_answer_bonus
# else: penalty for missing/malformed boxed answer, or neutral
logger.info(f"Item {i}: Correct: {is_correct}, Think Bonus: {self.config.think_block_bonus if think_match and think_match.group(1).strip() else 0}, Boxed Bonus: {self.config.boxed_answer_bonus if boxed_answer_content is not None else 0}, Final Score: {current_score}")
# Tokenize for the trainer
tokens_dict = tokenize_for_trainer(
self.tokenizer,
messages,
None,
messages, # These are the full messages including system, user, assistant
None, # Not used by this tokenizer function apparently
)
# Add to scored data
scored_data["tokens"].append(tokens_dict["tokens"])
scored_data["masks"].append(tokens_dict["masks"])
scored_data["scores"].append(combined_score)
scored_data["scores"].append(current_score)
scored_data["messages"].append(messages)
# Advance difficulty if criteria met