mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-19 12:57:58 +00:00
363 lines
12 KiB
Python
363 lines
12 KiB
Python
"""Reward function that combines reasoning format and accuracy rewards."""
|
|
|
|
import logging
|
|
import re
|
|
from typing import Any, Dict, List, Optional, Union
|
|
|
|
from .registry import registry
|
|
from .reward_function import RewardFunction
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
def parse_reasoning_response(text: str) -> Dict[str, Any]:
    """
    Split a completion into its <think> section and the response after it.

    Args:
        text: Raw completion text. Non-string values are logged and
            returned stringified as the response.

    Returns:
        Dictionary with three keys:
            thinking_content: Stripped text inside the <think> block, or ""
                when no single well-formed block was found.
            response: Stripped text following </think>; the unmodified input
                when no block matched or when several blocks were found.
            multiple_thinking: True only when more than one
                <think>...</think> block is present.
    """

    def _result(thinking: str, response: str, multiple: bool) -> Dict[str, Any]:
        # Single place that fixes the shape of the returned dictionary.
        return {
            "thinking_content": thinking,
            "response": response,
            "multiple_thinking": multiple,
        }

    # Guard against callers handing over something that is not text.
    if not isinstance(text, str):
        logger.warning(f"Expected string but got {type(text)}: {text}")
        return _result("", str(text), False)

    # More than one closed thinking block is reported as a failure,
    # with the whole text passed back as the response.
    block_count = len(re.findall(r"<think>.*?</think>", text, re.DOTALL))
    if block_count > 1:
        return _result("", text, True)

    # Capture the thinking body and everything after the closing tag.
    # Text that precedes <think> is not part of either capture.
    found = re.search(r"<think>\s*(.*?)\s*</think>\s*(.*)", text, re.DOTALL)
    if found is None:
        return _result("", text, False)

    thinking_part, response_part = found.groups()
    return _result(thinking_part.strip(), response_part.strip(), False)
|
|
|
|
|
|
@registry.register
class FormatReasoningReward(RewardFunction):
    """
    Reward completions that follow the expected reasoning layout.

    Depending on the configured flags, a completion is well-formed when:
    1. It has non-empty thinking content inside <think> tags
    2. It has non-empty response text after the thinking tags
    3. It does not contain more than one <think> block
    """

    def __init__(
        self,
        reward_value: float = 0.5,
        require_thinking: bool = True,
        require_response: bool = True,
        allow_multiple_thinking: bool = False,
        weight: float = 1.0,
        **kwargs,
    ):
        """
        Store the format-checking configuration.

        Args:
            reward_value: Reward returned for a well-formed completion
            require_thinking: Reject completions with empty thinking content
            require_response: Reject completions with empty response content
            allow_multiple_thinking: Accept completions flagged as having
                several thinking blocks
            weight: Weight forwarded to the base class
            **kwargs: Extra configuration forwarded to the base class
        """
        super().__init__(weight=weight, **kwargs)
        self.reward_value = reward_value
        self.require_thinking = require_thinking
        self.require_response = require_response
        self.allow_multiple_thinking = allow_multiple_thinking

    def compute(self, completions: List[Any], **kwargs) -> List[float]:
        """
        Score each completion on its reasoning format.

        Args:
            completions: Completions to evaluate
            **kwargs: Additional context (not used here)

        Returns:
            One reward per completion: reward_value when the format checks
            pass, 0.0 otherwise
        """
        # First pass: turn every completion into its parsed sections.
        # A completion that cannot be parsed becomes an all-empty record.
        parsed = []
        for item in completions:
            try:
                sections = parse_reasoning_response(self.get_content(item))
            except Exception as e:
                logger.error(f"Error parsing response: {e}")
                logger.exception(e)
                sections = {
                    "thinking_content": "",
                    "response": "",
                    "multiple_thinking": False,
                }
            parsed.append(sections)

        # Second pass: apply the configured requirements to each record.
        rewards = []
        for p in parsed:
            try:
                no_thinking = self.require_thinking and not p["thinking_content"]
                no_response = self.require_response and not p["response"]
                too_many_blocks = (
                    not self.allow_multiple_thinking and p["multiple_thinking"]
                )

                if no_thinking or no_response or too_many_blocks:
                    score = 0.0
                else:
                    score = self.reward_value
            except Exception as e:
                logger.error(f"Error in format reward calculation: {e}")
                logger.exception(e)
                score = 0.0
            rewards.append(score)

        return rewards
|
|
|
|
|
|
@registry.register
class AccuracyXReward(RewardFunction):
    """
    Reward function that checks if completion responses contain the solution.

    Each completion is first parsed with parse_reasoning_response so that only
    the response part (the text after the thinking tags, or the whole text
    when no thinking block matched) is compared against the solution.
    """

    def __init__(
        self,
        exact_match: bool = False,
        case_sensitive: bool = False,
        reward_value: float = 1.0,
        weight: float = 1.0,
        **kwargs,
    ):
        """
        Initialize the accuracy reward function.

        Args:
            exact_match: Require the response to equal the solution instead
                of merely containing it
            case_sensitive: Compare without lower-casing both sides
            reward_value: Value to award for a correct answer
            weight: Weight forwarded to the base class
            **kwargs: Additional configuration forwarded to the base class
        """
        super().__init__(weight=weight, **kwargs)
        self.exact_match = exact_match
        self.case_sensitive = case_sensitive
        self.reward_value = reward_value

    def compute(
        self,
        completions: List[Any],
        solution: Optional[Union[str, List[str]]] = None,
        **kwargs,
    ) -> List[float]:
        """
        Check if completion responses contain the solution.

        Args:
            completions: List of completions to evaluate
            solution: A single solution applied to every completion, or a
                list with one solution per completion
            **kwargs: Additional context (not used here)

        Returns:
            Exactly one reward per completion: reward_value for a match,
            0.0 otherwise. When a solution list is shorter than the
            completions, the completions without a solution receive 0.0.
        """
        if solution is None:
            logger.warning("No solution provided for accuracy reward")
            return [0.0] * len(completions)

        # Parse every completion; unparseable ones become empty responses.
        parsed_responses = []
        for completion in completions:
            try:
                content = self.get_content(completion)
                parsed_responses.append(parse_reasoning_response(content))
            except Exception as e:
                logger.error(f"Error parsing response: {e}")
                logger.exception(e)
                parsed_responses.append(
                    {"thinking_content": "", "response": "", "multiple_thinking": False}
                )

        # Broadcast a single solution across all completions.
        if isinstance(solution, list):
            solutions = solution
        else:
            solutions = [solution] * len(parsed_responses)

        # zip() would silently drop completions beyond the end of a short
        # solution list, returning fewer rewards than completions. Score
        # by index instead so the output always lines up with the input.
        if len(solutions) < len(parsed_responses):
            logger.warning(
                "Received %d solutions for %d completions; "
                "completions without a solution receive 0.0",
                len(solutions),
                len(parsed_responses),
            )

        rewards = []
        for index, resp in enumerate(parsed_responses):
            if index >= len(solutions):
                rewards.append(0.0)
                continue

            sol = solutions[index]
            try:
                # Non-string solutions are unwrapped the same way as completions.
                sol_content = self.get_content(sol) if not isinstance(sol, str) else sol
                resp_content = resp["response"]

                # Do the matching based on settings
                if not self.case_sensitive:
                    sol_content = sol_content.lower()
                    resp_content = resp_content.lower()

                if self.exact_match:
                    match = resp_content == sol_content
                else:
                    match = sol_content in resp_content

                rewards.append(self.reward_value if match else 0.0)
            except Exception as e:
                logger.error(f"Error in accuracy reward calculation: {e}")
                logger.exception(e)
                rewards.append(0.0)

        return rewards
|
|
|
|
|
|
@registry.register
class R1Reward(RewardFunction):
    """
    Weighted sum of a reasoning-format reward and an accuracy reward.

    Components:
    1. FormatReasoningReward - scores <think> / response formatting
    2. AccuracyXReward - scores whether the response holds the solution
    """

    def __init__(
        self,
        format_weight: float = 0.5,
        accuracy_weight: float = 1.0,
        weight: float = 1.0,
        **kwargs,
    ):
        """
        Set up the combined reward and its two component reward functions.

        Args:
            format_weight: Multiplier applied to the format component
            accuracy_weight: Multiplier applied to the accuracy component
            weight: Weight for the overall reward, forwarded to the base class
            **kwargs: Additional configuration forwarded to the base class
        """
        super().__init__(weight=weight, **kwargs)
        self.format_weight = format_weight
        self.accuracy_weight = accuracy_weight

        # Both components are built with unit reward value and unit weight;
        # the per-component weights above are applied in compute().
        self.format_reward_fn = FormatReasoningReward(reward_value=1.0, weight=1.0)
        self.accuracy_reward_fn = AccuracyXReward(reward_value=1.0, weight=1.0)

    def compute(
        self,
        completions: List[Any],
        solution: Optional[Union[str, List[str]]] = None,
        **kwargs,
    ) -> List[float]:
        """
        Combine the format and accuracy rewards for each completion.

        Args:
            completions: Completions to evaluate
            solution: Solution passed to the accuracy component
            **kwargs: Additional context forwarded to both components

        Returns:
            List of combined rewards; all zeros (one per completion) if
            anything raises while computing them
        """
        try:
            # Raw, unweighted component scores.
            format_rewards = self.format_reward_fn.compute(completions, **kwargs)
            accuracy_rewards = self.accuracy_reward_fn.compute(
                completions, solution=solution, **kwargs
            )

            # Pairwise weighted sum of the two component lists.
            rewards = []
            for fmt_score, acc_score in zip(format_rewards, accuracy_rewards):
                combined = (fmt_score * self.format_weight) + (
                    acc_score * self.accuracy_weight
                )
                rewards.append(combined)

            logger.info(
                f"R1 rewards: accuracy={accuracy_rewards}, format={format_rewards}, combined={rewards}"
            )
            return rewards
        except Exception as e:
            logger.error(f"Error in r1_reward: {e}")
            logger.exception(e)
            # Fall back to a zero reward for every completion.
            return [0.0] * len(completions)
|
|
|
|
|
|
# Legacy function wrappers kept for backward compatibility
|
|
def format_reasoning_reward(completions: List[Any], **kwargs) -> List[float]:
    """
    Function-style entry point for FormatReasoningReward.

    Builds a FormatReasoningReward with reward_value=0.5 and returns the
    result of its compute() call.

    Args:
        completions: Completions to evaluate
        **kwargs: Extra parameters forwarded to compute()

    Returns:
        List of format rewards
    """
    return FormatReasoningReward(reward_value=0.5).compute(completions, **kwargs)
|
|
|
|
|
|
def accuracy_reward(
    completions: List[Any], solution: Union[str, List[str]], **kwargs
) -> List[float]:
    """
    Function-style entry point for AccuracyXReward.

    Builds an AccuracyXReward with its default settings and returns the
    result of its compute() call.

    Args:
        completions: Completions to evaluate
        solution: Solution to look for
        **kwargs: Extra parameters forwarded to compute()

    Returns:
        List of accuracy rewards
    """
    return AccuracyXReward().compute(completions, solution=solution, **kwargs)
|
|
|
|
|
|
def r1_reward(
    completions: List[Any], solution: Union[str, List[str]], **kwargs
) -> List[float]:
    """
    Function-style entry point for R1Reward.

    Builds an R1Reward with its default weights and returns the result of
    its compute() call.

    Args:
        completions: Completions to evaluate
        solution: Solution to look for
        **kwargs: Extra parameters forwarded to compute()

    Returns:
        List of combined rewards
    """
    return R1Reward().compute(completions, solution=solution, **kwargs)
|