atropos/atroposlib/envs/reward_fns/reasoning_steps_reward.py
2025-04-29 12:10:10 -07:00

138 lines
4.8 KiB
Python

"""Reward function for evaluating step-by-step reasoning in completions."""
import logging
import re
from typing import Any, Dict, List, Optional
from .registry import registry
from .reward_function import RewardFunction
# Module-level logger; ReasoningStepsReward.compute reports the per-completion
# pattern match counts and resulting score through it at INFO level.
logger = logging.getLogger(__name__)
@registry.register
class ReasoningStepsReward(RewardFunction):
    r"""
    Reward function that evaluates step-by-step reasoning in completions.

    Looks for several types of step-by-step reasoning indicators:
    1. Numbered step patterns like "Step 1:", "Step 2:"
    2. Numbered lists like "1.", "2." at start of line
    3. Bullet points with hyphens, asterisks or "•" at start of line
    4. Sequential transition words (First, Second, Next, Finally, etc.)

    Each indicator type contributes ``min(1.0, matches / min_steps) * weight``
    to a completion's score. A small base score is added for completions
    longer than ``min_words`` words, and the total is capped at 1.0.
    """

    def __init__(
        self,
        min_words: int = 10,
        min_steps: int = 3,
        base_score: float = 0.1,
        pattern_weights: Optional[Dict[str, float]] = None,
        weight: float = 1.0,
        **kwargs,
    ):
        """
        Initialize the reasoning steps reward function.

        Args:
            min_words: Minimum number of words to consider for base score
            min_steps: Number of steps needed for full points in each category;
                must be positive because it is used as a divisor in ``compute``
            base_score: Base score for having content longer than min_words
            pattern_weights: Custom weights for each pattern type (optional);
                merged over the defaults, so unspecified types keep theirs
            weight: Weight for this reward
            **kwargs: Additional configuration, forwarded to the base class

        Raises:
            ValueError: If ``min_steps`` is zero or negative.
        """
        # Fail fast: a zero min_steps would raise ZeroDivisionError on every
        # compute() call, and a negative one would silently invert the score.
        if min_steps <= 0:
            raise ValueError(f"min_steps must be positive, got {min_steps!r}")

        super().__init__(weight=weight, **kwargs)
        self.min_words = min_words
        self.min_steps = min_steps
        self.base_score = base_score

        # Default pattern weights
        self.pattern_weights = {
            "numbered_steps": 0.5,  # Strong indicators
            "list_numbers": 0.5,  # Strong indicators
            "bullet_points": 0.4,  # Medium indicators
            "transition_words": 0.3,  # Weaker indicators
        }
        # Override with custom weights if provided
        if pattern_weights:
            self.pattern_weights.update(pattern_weights)

        # Patterns for different types of step indicators
        self.patterns = {
            # Step 1: style numbered steps
            "numbered_steps": r"Step\s+\d+[\s:]+",
            # Numbered lists (1., 2., etc.)
            "list_numbers": r"(?:^|\n)\s*\d+\.\s+",
            # Bullet points
            "bullet_points": r"(?:^|\n)\s*[\-\*•]\s+",
            # Sequential transition words - expanded to include more phrases
            "transition_words": r"\b(?:First|Second|Third|Fourth|Fifth|Next|Then|Finally|"
            r"Subsequently|Afterward|Lastly|Initially|To begin|Let\'s begin|"
            r"I\'ll first|After that|In conclusion|Eventually|Subsequently|"
            r"To solve|begin by|understand|analyze|apply|compute)\b",
        }

    def compute(self, completions: List[Any], **kwargs) -> List[float]:
        """
        Calculate reasoning quality scores based on pattern matching.

        Args:
            completions: List of completions to evaluate; each one is passed
                through ``self.get_content`` to obtain the text to score
            **kwargs: Additional context (unused here)

        Returns:
            One score per completion, in input order. Scores are capped at
            1.0 and are non-negative with the default weights and base score.
        """
        # Extract content from different possible formats
        completion_contents = [
            self.get_content(completion) for completion in completions
        ]

        # Resolve the loop-invariant pieces once per call. This is done here
        # rather than in __init__ so that later edits to self.patterns or
        # self.pattern_weights are still honoured.
        flags = re.IGNORECASE | re.MULTILINE
        scorers = [
            (
                pattern_type,
                re.compile(pattern, flags),
                # Default weight if not specified
                self.pattern_weights.get(pattern_type, 0.3),
            )
            for pattern_type, pattern in self.patterns.items()
        ]

        rewards = []
        for content in completion_contents:
            score = 0.0
            pattern_matches = {}

            # Check for each type of pattern
            for pattern_type, regex, weight in scorers:
                match_count = len(regex.findall(content))
                pattern_matches[pattern_type] = match_count
                # Full credit for a category once min_steps matches are seen
                score += min(1.0, match_count / self.min_steps) * weight

            # Add a small base score for any content that has more than just an answer
            # This helps differentiate minimal reasoning from no reasoning
            if len(content.split()) > self.min_words:
                score += self.base_score

            # Cap the total score at 1.0
            score = min(1.0, score)
            rewards.append(score)

            # Lazy %-args: the message is only rendered if INFO is enabled.
            logger.info(
                "Reasoning steps reward for completion: %s, score: %s",
                pattern_matches,
                score,
            )
        return rewards
# Function-style entry point kept for backward compatibility
def reasoning_steps_reward(completions: List[Any], **kwargs) -> List[float]:
    """
    Score completions with a default-configured ReasoningStepsReward.

    A fresh ReasoningStepsReward is built with its default settings on every
    call and its ``compute`` method does the actual scoring.

    Args:
        completions: List of completions to evaluate
        **kwargs: Additional parameters, passed through to ``compute``

    Returns:
        List of reward scores between 0.0 and 1.0
    """
    return ReasoningStepsReward().compute(completions, **kwargs)