mirror of https://github.com/NousResearch/atropos.git
synced 2026-04-19 12:57:58 +00:00
removed reward function registry
parent 4e7fcd3c9a
commit 137f8381ec
2 changed files with 56 additions and 164 deletions
@@ -5,22 +5,21 @@ import random
 import re
 from typing import Any, Dict, List, Optional, Tuple, Union
 
-from trajectoryhandler.envs.base import (
+from atroposlib.envs.base import (
     BaseEnv,
     BaseEnvConfig,
     OpenaiConfig,
     ScoredDataGroup,
 )
-from trajectoryhandler.envs.reward_fns import registry
-from trajectoryhandler.envs.reward_fns.combined_reward import CombinedReward
-from trajectoryhandler.utils.tokenize_for_trainer import tokenize_for_trainer
+from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer
 
 from .curriculum import MathCurriculum
 
 logger = logging.getLogger(__name__)
 logger.setLevel(logging.DEBUG)
 
-system_prompt = """You are an expert mathematician. You need to solve the given math problem step-by-step, showing your reasoning clearly. You should enclose your thoughts and internal monologue inside <think> </think> tags, and then provide your final answer in a LaTeX format using \\boxed{your answer here}.
+system_prompt = """You are an expert mathematician that can use extremely long chains of thought to deeply consider the problem and deliberate with yourself via systematic reasoning processes to help come to a correct solution prior to answering.
+You should enclose your thoughts and internal monologue inside <think> </think> tags, and then provide your final answer in a LaTeX format using \\boxed{your answer here}.
 
 The problems will be given in a LaTeX format, so be sure to follow the LaTeX syntax when writing your answer (although no $ delimiters are necessary).
@@ -55,6 +54,8 @@ class InfiniteMathEnvConfig(BaseEnvConfig):
     max_attempts_per_problem: int = 3
     correct_reward: float = 1.0
     incorrect_reward: float = -1.0
+    think_block_bonus: float = 0.2  # Bonus for a well-formed think block
+    boxed_answer_bonus: float = 0.2  # Bonus for a well-formed boxed answer
 
     # Length penalty parameters
     apply_length_penalty: bool = True
@@ -66,16 +67,6 @@ class InfiniteMathEnvConfig(BaseEnvConfig):
     temperature: float = 0.7
     top_p: float = 0.9
 
-    # Reward functions
-    reward_functions: List[Union[str, Dict[str, Any]]] = ["accuracy", "format", "boxed"]
-    accuracy_reward_weight: float = 1.0  # Weight for the accuracy reward
-    format_reward_weight: float = (
-        0.2  # Weight for the format reward relative to correctness
-    )
-    boxed_reward_weight: float = (
-        0.3  # Weight for the boxed answer reward relative to correctness
-    )
-
 
 class InfiniteMathEnv(BaseEnv):
     """Environment for procedurally generated math problems with curriculum advancement."""
@@ -103,70 +94,6 @@ class InfiniteMathEnv(BaseEnv):
         # Set the system prompt
         self.system_prompt = system_prompt
 
-        # Initialize reward function
-        self.reward_function = self._initialize_reward_function()
-
-    def _initialize_reward_function(self):
-        """Initialize the combined reward function for scoring."""
-        if hasattr(self.config, "reward_functions") and self.config.reward_functions:
-            # Configure parameters for specific reward functions
-            reward_configs = []
-
-            for reward_func in self.config.reward_functions:
-                if isinstance(reward_func, str):
-                    # String name case - handle known rewards with custom params
-                    if reward_func == "accuracy":
-                        # Configure accuracy reward
-                        accuracy_config = {
-                            "type": "accuracy",
-                            "weight": self.config.accuracy_reward_weight,
-                            "params": {
-                                "split_on_think_tag": True,  # Only look at what's after </think>
-                                "tolerance": 1e-6,  # Tolerance for number comparisons
-                            },
-                        }
-                        logger.info(f"Adding accuracy reward with config: {accuracy_config}")
-                        reward_configs.append(accuracy_config)
-                    elif reward_func == "format":
-                        # Configure format reward with think tags and explicit weight
-                        format_config = {
-                            "type": "format",
-                            "weight": self.config.format_reward_weight,
-                            "params": {
-                                "preferred_tags": ["think"],
-                            },
-                        }
-                        logger.info(f"Adding format reward with config: {format_config}")
-                        reward_configs.append(format_config)
-                    elif reward_func == "boxed":
-                        # Configure boxed reward with proper parameters and explicit weight
-                        boxed_config = {
-                            "type": "boxed",
-                            "weight": self.config.boxed_reward_weight,
-                            "params": {
-                                "require_outside_think": True,
-                            },
-                        }
-                        logger.info(f"Adding boxed reward with config: {boxed_config}")
-                        reward_configs.append(boxed_config)
-                    else:
-                        # Pass through other reward functions as is
-                        logger.info(f"Adding generic reward function: {reward_func}")
-                        reward_configs.append(reward_func)
-                else:
-                    # Dict case - pass through as is
-                    logger.info(f"Adding reward config: {reward_func}")
-                    reward_configs.append(reward_func)
-
-            # Create the reward function(s)
-            if len(reward_configs) == 1:
-                logger.info(f"Creating single reward function: {reward_configs[0]}")
-                return registry.create(reward_configs[0])
-            else:
-                logger.info(f"Creating combined reward function with {len(reward_configs)} rewards")
-                # Add explicit normalization to sum to 1.0
-                return CombinedReward(rewards=reward_configs, normalization="none")
-
     async def setup(self):
         """Initialize the environment and curriculum."""
         logger.info("Setting up InfiniteMathEnv")
@@ -340,21 +267,22 @@ class InfiniteMathEnv(BaseEnv):
         )
 
-        # Log reward function metrics
-        if hasattr(self, "reward_function") and self.wandb:
-            if hasattr(self.reward_function, "set_wandb_logger"):
-                self.reward_function.set_wandb_logger(self.wandb)
+        # REMOVED: Specific reward function config logging as it's not used anymore
+        # if hasattr(self, "reward_function") and self.wandb:
+        #     if hasattr(self.reward_function, "set_wandb_logger"):
+        #         self.reward_function.set_wandb_logger(self.wandb)
 
-        # Log the reward configurations
-        if isinstance(self.config.reward_functions, list) and self.config.reward_functions:
-            # Log the reward configuration
-            wandb_metrics["reward/format_reward_enabled"] = "format" in self.config.reward_functions
-            wandb_metrics["reward/boxed_reward_enabled"] = "boxed" in self.config.reward_functions
+        # # Log the reward configurations
+        # if isinstance(self.config.reward_functions, list) and self.config.reward_functions:
+        #     # Log the reward configuration
+        #     wandb_metrics["reward/format_reward_enabled"] = "format" in self.config.reward_functions
+        #     wandb_metrics["reward/boxed_reward_enabled"] = "boxed" in self.config.reward_functions
 
-        if hasattr(self.config, "format_reward_weight"):
-            wandb_metrics["reward/format_reward_weight"] = self.config.format_reward_weight
+        # if hasattr(self.config, "format_reward_weight"):
+        #     wandb_metrics["reward/format_reward_weight"] = self.config.format_reward_weight
 
-        if hasattr(self.config, "boxed_reward_weight"):
-            wandb_metrics["reward/boxed_reward_weight"] = self.config.boxed_reward_weight
+        # if hasattr(self.config, "boxed_reward_weight"):
+        #     wandb_metrics["reward/boxed_reward_weight"] = self.config.boxed_reward_weight
 
         # Add eval metrics
         for item in self.eval_metrics:
@@ -502,7 +430,7 @@ class InfiniteMathEnv(BaseEnv):
         )
 
         # Extract the boxed answer if present
-        boxed_answer = self.extract_boxed_answer(after_think_part)
+        boxed_answer = self._extract_boxed_answer(after_think_part)
         if not boxed_answer:
             # Try to find the answer in the last line
             lines = after_think_part.strip().split("\n")
@@ -510,15 +438,15 @@ class InfiniteMathEnv(BaseEnv):
                 boxed_answer = lines[-1].strip()
 
         # Clean up answers for comparison (remove spaces, convert to lowercase)
-        model_clean = self.clean_for_comparison(
+        model_clean = self._clean_for_comparison(
            boxed_answer if boxed_answer else after_think_part
        )
-        solution_clean = self.clean_for_comparison(solution)
+        solution_clean = self._clean_for_comparison(solution)
 
         # Check if they match
         return model_clean == solution_clean
 
-    def extract_boxed_answer(self, text: str) -> Optional[str]:
+    def _extract_boxed_answer(self, text: str) -> Optional[str]:
         """Extract answer from a LaTeX boxed expression."""
         # Try to find boxed content
         boxed_match = re.search(r"\\boxed{([^}]*)}", text)
@@ -526,7 +454,7 @@ class InfiniteMathEnv(BaseEnv):
             return boxed_match.group(1)
         return None
 
-    def clean_for_comparison(self, text: str) -> str:
+    def _clean_for_comparison(self, text: str) -> str:
         """Clean text for comparison."""
         # Remove LaTeX commands, spaces, commas, and convert to lowercase
         cleaned = re.sub(r"\\[a-zA-Z]+", "", text)
@@ -604,86 +532,54 @@ class InfiniteMathEnv(BaseEnv):
         scored_data["scores"] = []
         scored_data["messages"] = []
 
-        # Format completions for reward function evaluation
-        format_completions = []
-
         # Process each item in the rollout data
-        for messages, solution, generator_id, level in rollout_group_data:
-            # Extract the model's answer
-            model_answer = messages[-1]["content"]
-
-            # Add to format completions list for reward function
-            format_completions.append([{"role": "assistant", "content": model_answer}])
-
-            # Record performance in curriculum based on the answer and solution
-            # This will be updated after the reward functions are applied
-
-        # Apply all reward functions
-        reward_scores = []
-        unweighted_scores = []
-        if hasattr(self, "reward_function") and self.reward_function:
-            try:
-                # Apply the reward function (which may be a combined reward)
-                reward_scores = self.reward_function(format_completions, solution=solution)
-                logger.info(f"Reward scores: {reward_scores}")
-
-                # Debug individual rewards if it's a combined reward
-                if hasattr(self.reward_function, "rewards"):
-                    logger.info(f"Combined reward with {len(self.reward_function.rewards)} components")
-                    for i, reward in enumerate(self.reward_function.rewards):
-                        if hasattr(reward, "compute"):
-                            # Get raw unweighted scores
-                            raw_scores = reward.compute(format_completions, solution=solution)
-                            if hasattr(reward, "weight"):
-                                logger.info(f"Reward {i} ({type(reward).__name__}): raw={raw_scores}, weight={reward.weight}")
-                            else:
-                                logger.info(f"Reward {i} ({type(reward).__name__}): raw={raw_scores}")
-                else:
-                    logger.info(f"Using single reward: {type(self.reward_function).__name__}")
-
-            except Exception as e:
-                logger.error(f"Error applying reward functions: {e}")
-                logger.exception(e)
-                reward_scores = [0.0] * len(format_completions)
-
-        # Now update curriculum based on accuracy reward results
         for i, (messages, solution, generator_id, level) in enumerate(rollout_group_data):
-            # Extract accuracy from the combined reward if available
-            is_correct = False
-            if reward_scores and hasattr(self.reward_function, "rewards"):
-                for reward in self.reward_function.rewards:
-                    if type(reward).__name__ == "AccuracyReward":
-                        # Get raw scores from accuracy reward
-                        accuracy_scores = reward.compute(format_completions, solution=solution)
-                        is_correct = accuracy_scores[i] > 0
-                        break
             model_answer = messages[-1]["content"]
+            current_score = 0.0
+
+            # 1. Accuracy Check
+            is_correct = self.check_answer(model_answer, solution)
+            if is_correct:
+                current_score += self.config.correct_reward
+            else:
+                current_score += self.config.incorrect_reward
 
-            # Record answer correctness for tracking
+            # Record answer correctness for tracking and curriculum
             self.percent_correct_buffer.append(1 if is_correct else 0)
             if level is not None:
                 self.level_correct_buffer[level].append(1 if is_correct else 0)
 
             # Record performance in curriculum
             self.curriculum.record_performance(generator_id, is_correct)
 
-        # Combine scores and add to scored data
-        for i, (messages, _, _, _) in enumerate(rollout_group_data):
-            # Use the reward score directly (all weights are applied)
-            combined_score = reward_scores[i] if reward_scores else 0.0
-
-            logger.info(f"Final score for item {i}: {combined_score}")
+            # 2. Thinking Block Check
+            think_match = re.search(r"<think>(.*?)</think>", model_answer, re.DOTALL)
+            if think_match:
+                think_content = think_match.group(1).strip()
+                if think_content:  # Check if there's actual content
+                    current_score += self.config.think_block_bonus
+                # else: penalty for empty think block, or neutral
+            # else: penalty for missing think block, or neutral
+
+            # 3. Boxed Answer Check
+            # Extract the part after the thinking block for boxed answer validation
+            after_think_part = model_answer.split("</think>")[-1].strip() if "</think>" in model_answer else model_answer
+            boxed_answer_content = self._extract_boxed_answer(after_think_part)
+            if boxed_answer_content is not None:  # Check if \boxed{} is present and has content
+                current_score += self.config.boxed_answer_bonus
+            # else: penalty for missing/malformed boxed answer, or neutral
+
+            logger.info(f"Item {i}: Correct: {is_correct}, Think Bonus: {self.config.think_block_bonus if think_match and think_match.group(1).strip() else 0}, Boxed Bonus: {self.config.boxed_answer_bonus if boxed_answer_content is not None else 0}, Final Score: {current_score}")
 
             # Tokenize for the trainer
             tokens_dict = tokenize_for_trainer(
                 self.tokenizer,
-                messages,
-                None,
+                messages,  # These are the full messages including system, user, assistant
+                None,  # Not used by this tokenizer function apparently
             )
 
             # Add to scored data
             scored_data["tokens"].append(tokens_dict["tokens"])
             scored_data["masks"].append(tokens_dict["masks"])
-            scored_data["scores"].append(combined_score)
+            scored_data["scores"].append(current_score)
             scored_data["messages"].append(messages)
 
         # Advance difficulty if criteria met