mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-28 17:29:30 +00:00
first commit
This commit is contained in:
commit
621d00dd80
89 changed files with 15315 additions and 0 deletions
363
atroposlib/envs/reward_fns/r1_reward.py
Normal file
363
atroposlib/envs/reward_fns/r1_reward.py
Normal file
|
|
@ -0,0 +1,363 @@
|
|||
"""Reward function that combines reasoning format and accuracy rewards."""
|
||||
|
||||
import logging
|
||||
import re
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
from .registry import registry
|
||||
from .reward_function import RewardFunction
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def parse_reasoning_response(text: str) -> Dict[str, Any]:
|
||||
"""
|
||||
Parse text to extract thinking section and response section.
|
||||
|
||||
Args:
|
||||
text: Text to parse for thinking and response sections
|
||||
|
||||
Returns:
|
||||
Dictionary with thinking_content, response, and multiple_thinking flag
|
||||
"""
|
||||
# Check if text is actually a string
|
||||
if not isinstance(text, str):
|
||||
logger.warning(f"Expected string but got {type(text)}: {text}")
|
||||
return {
|
||||
"thinking_content": "",
|
||||
"response": str(text),
|
||||
"multiple_thinking": False,
|
||||
}
|
||||
|
||||
# Find all thinking blocks
|
||||
thinking_blocks = re.findall(r"<think>.*?</think>", text, re.DOTALL)
|
||||
|
||||
# If there's more than one thinking block, fail
|
||||
if len(thinking_blocks) > 1:
|
||||
return {"thinking_content": "", "response": text, "multiple_thinking": True}
|
||||
|
||||
# Match the single thinking block if it exists
|
||||
pattern = r"<think>\s*(.*?)\s*</think>\s*(.*)"
|
||||
match = re.search(pattern, text, re.DOTALL)
|
||||
if not match:
|
||||
return {"thinking_content": "", "response": text, "multiple_thinking": False}
|
||||
|
||||
return {
|
||||
"thinking_content": match.group(1).strip(),
|
||||
"response": match.group(2).strip(),
|
||||
"multiple_thinking": False,
|
||||
}
|
||||
|
||||
|
||||
@registry.register
|
||||
class FormatReasoningReward(RewardFunction):
|
||||
"""
|
||||
Reward function that checks for proper reasoning format.
|
||||
|
||||
Checks if completions have:
|
||||
1. A thinking section in <think> tags
|
||||
2. A response section after the thinking tags
|
||||
3. No multiple thinking sections (only one <think> block)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
reward_value: float = 0.5,
|
||||
require_thinking: bool = True,
|
||||
require_response: bool = True,
|
||||
allow_multiple_thinking: bool = False,
|
||||
weight: float = 1.0,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Initialize the format reasoning reward function.
|
||||
|
||||
Args:
|
||||
reward_value: Value to award for correct formatting
|
||||
require_thinking: Whether to require thinking content
|
||||
require_response: Whether to require response content
|
||||
allow_multiple_thinking: Whether to allow multiple thinking sections
|
||||
weight: Weight for this reward
|
||||
**kwargs: Additional configuration
|
||||
"""
|
||||
super().__init__(weight=weight, **kwargs)
|
||||
self.reward_value = reward_value
|
||||
self.require_thinking = require_thinking
|
||||
self.require_response = require_response
|
||||
self.allow_multiple_thinking = allow_multiple_thinking
|
||||
|
||||
def compute(self, completions: List[Any], **kwargs) -> List[float]:
|
||||
"""
|
||||
Check if completions have proper reasoning format.
|
||||
|
||||
Args:
|
||||
completions: List of completions to evaluate
|
||||
**kwargs: Additional context
|
||||
|
||||
Returns:
|
||||
List of rewards (reward_value for correct format, 0.0 otherwise)
|
||||
"""
|
||||
parsed = []
|
||||
for completion in completions:
|
||||
try:
|
||||
content = self.get_content(completion)
|
||||
parsed.append(parse_reasoning_response(content))
|
||||
except Exception as e:
|
||||
logger.error(f"Error parsing response: {e}")
|
||||
logger.exception(e)
|
||||
parsed.append(
|
||||
{"thinking_content": "", "response": "", "multiple_thinking": False}
|
||||
)
|
||||
|
||||
rewards = []
|
||||
for p in parsed:
|
||||
try:
|
||||
# Check if response meets format requirements
|
||||
valid_format = True
|
||||
|
||||
if self.require_thinking and not p["thinking_content"]:
|
||||
valid_format = False
|
||||
|
||||
if self.require_response and not p["response"]:
|
||||
valid_format = False
|
||||
|
||||
if not self.allow_multiple_thinking and p["multiple_thinking"]:
|
||||
valid_format = False
|
||||
|
||||
rewards.append(self.reward_value if valid_format else 0.0)
|
||||
except Exception as e:
|
||||
logger.error(f"Error in format reward calculation: {e}")
|
||||
logger.exception(e)
|
||||
rewards.append(0.0)
|
||||
|
||||
return rewards
|
||||
|
||||
|
||||
@registry.register
|
||||
class AccuracyXReward(RewardFunction):
|
||||
"""
|
||||
Reward function that checks if completion responses contain the solution.
|
||||
|
||||
First parses completions to extract the response part after thinking tags,
|
||||
then checks if the solution string is contained in the response.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
exact_match: bool = False,
|
||||
case_sensitive: bool = False,
|
||||
reward_value: float = 1.0,
|
||||
weight: float = 1.0,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Initialize the accuracy reward function.
|
||||
|
||||
Args:
|
||||
exact_match: Whether to require exact match vs. contained
|
||||
case_sensitive: Whether to do case-sensitive matching
|
||||
reward_value: Value to award for correct answer
|
||||
weight: Weight for this reward
|
||||
**kwargs: Additional configuration
|
||||
"""
|
||||
super().__init__(weight=weight, **kwargs)
|
||||
self.exact_match = exact_match
|
||||
self.case_sensitive = case_sensitive
|
||||
self.reward_value = reward_value
|
||||
|
||||
def compute(
|
||||
self,
|
||||
completions: List[Any],
|
||||
solution: Optional[Union[str, List[str]]] = None,
|
||||
**kwargs,
|
||||
) -> List[float]:
|
||||
"""
|
||||
Check if completion responses contain the solution.
|
||||
|
||||
Args:
|
||||
completions: List of completions to evaluate
|
||||
solution: The solution to check for
|
||||
**kwargs: Additional context
|
||||
|
||||
Returns:
|
||||
List of rewards (reward_value for correct, 0.0 otherwise)
|
||||
"""
|
||||
if solution is None:
|
||||
logger.warning("No solution provided for accuracy reward")
|
||||
return [0.0] * len(completions)
|
||||
|
||||
parsed_responses = []
|
||||
for completion in completions:
|
||||
try:
|
||||
content = self.get_content(completion)
|
||||
parsed_responses.append(parse_reasoning_response(content))
|
||||
except Exception as e:
|
||||
logger.error(f"Error parsing response: {e}")
|
||||
logger.exception(e)
|
||||
parsed_responses.append(
|
||||
{"thinking_content": "", "response": "", "multiple_thinking": False}
|
||||
)
|
||||
|
||||
rewards = []
|
||||
|
||||
# Ensure solution is in the right format
|
||||
if not isinstance(solution, list):
|
||||
solution = [solution] * len(parsed_responses)
|
||||
|
||||
for resp, sol in zip(parsed_responses, solution):
|
||||
try:
|
||||
# Extract solution content if needed
|
||||
sol_content = self.get_content(sol) if not isinstance(sol, str) else sol
|
||||
resp_content = resp["response"]
|
||||
|
||||
# Do the matching based on settings
|
||||
if not self.case_sensitive:
|
||||
sol_content = sol_content.lower()
|
||||
resp_content = resp_content.lower()
|
||||
|
||||
if self.exact_match:
|
||||
match = resp_content == sol_content
|
||||
else:
|
||||
match = sol_content in resp_content
|
||||
|
||||
rewards.append(self.reward_value if match else 0.0)
|
||||
except Exception as e:
|
||||
logger.error(f"Error in accuracy reward calculation: {e}")
|
||||
logger.exception(e)
|
||||
rewards.append(0.0)
|
||||
|
||||
return rewards
|
||||
|
||||
|
||||
@registry.register
|
||||
class R1Reward(RewardFunction):
|
||||
"""
|
||||
Combined reward function that rewards both reasoning format and accuracy.
|
||||
|
||||
This reward function combines:
|
||||
1. FormatReasoningReward - rewards for proper <think> and response formatting
|
||||
2. AccuracyXReward - rewards for having the solution in the response
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
format_weight: float = 0.5,
|
||||
accuracy_weight: float = 1.0,
|
||||
weight: float = 1.0,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Initialize the R1 reward function.
|
||||
|
||||
Args:
|
||||
format_weight: Weight for the format component
|
||||
accuracy_weight: Weight for the accuracy component
|
||||
weight: Weight for the overall reward
|
||||
**kwargs: Additional configuration
|
||||
"""
|
||||
super().__init__(weight=weight, **kwargs)
|
||||
self.format_weight = format_weight
|
||||
self.accuracy_weight = accuracy_weight
|
||||
|
||||
# Create component reward functions
|
||||
self.format_reward_fn = FormatReasoningReward(
|
||||
reward_value=1.0, # Use 1.0 here, we'll apply weight in compute
|
||||
weight=1.0, # This will be overridden
|
||||
)
|
||||
|
||||
self.accuracy_reward_fn = AccuracyXReward(
|
||||
reward_value=1.0, # Use 1.0 here, we'll apply weight in compute
|
||||
weight=1.0, # This will be overridden
|
||||
)
|
||||
|
||||
def compute(
|
||||
self,
|
||||
completions: List[Any],
|
||||
solution: Optional[Union[str, List[str]]] = None,
|
||||
**kwargs,
|
||||
) -> List[float]:
|
||||
"""
|
||||
Calculate combined format and accuracy rewards.
|
||||
|
||||
Args:
|
||||
completions: List of completions to evaluate
|
||||
solution: The solution to check for accuracy
|
||||
**kwargs: Additional context
|
||||
|
||||
Returns:
|
||||
List of combined rewards
|
||||
"""
|
||||
try:
|
||||
# Calculate component rewards
|
||||
format_rewards = self.format_reward_fn.compute(completions, **kwargs)
|
||||
accuracy_rewards = self.accuracy_reward_fn.compute(
|
||||
completions, solution=solution, **kwargs
|
||||
)
|
||||
|
||||
# Apply component weights and combine
|
||||
rewards = [
|
||||
(f * self.format_weight) + (a * self.accuracy_weight)
|
||||
for f, a in zip(format_rewards, accuracy_rewards)
|
||||
]
|
||||
|
||||
logger.info(
|
||||
f"R1 rewards: accuracy={accuracy_rewards}, format={format_rewards}, combined={rewards}"
|
||||
)
|
||||
return rewards
|
||||
except Exception as e:
|
||||
logger.error(f"Error in r1_reward: {e}")
|
||||
logger.exception(e)
|
||||
# Return zero rewards for all completions
|
||||
return [0.0] * len(completions)
|
||||
|
||||
|
||||
# Legacy function for backward compatibility
|
||||
def format_reasoning_reward(completions: List[Any], **kwargs) -> List[float]:
|
||||
"""
|
||||
Legacy function wrapper for FormatReasoningReward.
|
||||
|
||||
Args:
|
||||
completions: List of completions to evaluate
|
||||
**kwargs: Additional parameters
|
||||
|
||||
Returns:
|
||||
List of rewards for format quality
|
||||
"""
|
||||
reward_fn = FormatReasoningReward(reward_value=0.5)
|
||||
return reward_fn.compute(completions, **kwargs)
|
||||
|
||||
|
||||
def accuracy_reward(
|
||||
completions: List[Any], solution: Union[str, List[str]], **kwargs
|
||||
) -> List[float]:
|
||||
"""
|
||||
Legacy function wrapper for AccuracyXReward.
|
||||
|
||||
Args:
|
||||
completions: List of completions to evaluate
|
||||
solution: The solution to check for
|
||||
**kwargs: Additional parameters
|
||||
|
||||
Returns:
|
||||
List of rewards for accuracy
|
||||
"""
|
||||
reward_fn = AccuracyXReward()
|
||||
return reward_fn.compute(completions, solution=solution, **kwargs)
|
||||
|
||||
|
||||
def r1_reward(
|
||||
completions: List[Any], solution: Union[str, List[str]], **kwargs
|
||||
) -> List[float]:
|
||||
"""
|
||||
Legacy function wrapper for R1Reward.
|
||||
|
||||
Args:
|
||||
completions: List of completions to evaluate
|
||||
solution: The solution to check for
|
||||
**kwargs: Additional parameters
|
||||
|
||||
Returns:
|
||||
List of combined rewards
|
||||
"""
|
||||
reward_fn = R1Reward()
|
||||
return reward_fn.compute(completions, solution=solution, **kwargs)
|
||||
Loading…
Add table
Add a link
Reference in a new issue