mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-22 16:48:57 +00:00
first commit
This commit is contained in:
commit
621d00dd80
89 changed files with 15315 additions and 0 deletions
105
atroposlib/envs/reward_fns/format_reward.py
Normal file
105
atroposlib/envs/reward_fns/format_reward.py
Normal file
|
|
@ -0,0 +1,105 @@
|
|||
"""Reward function for checking if completions have specific XML-style tags."""
|
||||
|
||||
import logging
|
||||
import re
|
||||
from typing import Any, List, Optional
|
||||
|
||||
from .registry import registry
|
||||
from .reward_function import RewardFunction
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@registry.register
|
||||
class FormatReward(RewardFunction):
|
||||
"""Reward function that checks if completions have XML-style tags."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
preferred_tags: Optional[List[str]] = None,
|
||||
require_all_tags: bool = False,
|
||||
case_sensitive: bool = False,
|
||||
weight: float = 1.0,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Initialize the format reward function.
|
||||
|
||||
Args:
|
||||
preferred_tags: List of tag names to search for (defaults to ['think', 'answer'])
|
||||
require_all_tags: If True, require all tags to be present for a reward
|
||||
case_sensitive: If True, perform case-sensitive tag matching
|
||||
weight: Weight for this reward
|
||||
**kwargs: Additional configuration
|
||||
"""
|
||||
super().__init__(weight=weight, **kwargs)
|
||||
self.preferred_tags = preferred_tags or ["think", "answer"]
|
||||
self.require_all_tags = require_all_tags
|
||||
self.case_sensitive = case_sensitive
|
||||
|
||||
def compute(self, completions: List[Any], **kwargs) -> List[float]:
|
||||
"""
|
||||
Check if completions have the expected XML-style tags.
|
||||
|
||||
Args:
|
||||
completions: List of completions to evaluate
|
||||
**kwargs: Additional context
|
||||
|
||||
Returns:
|
||||
List of rewards for each completion (1.0 for good format, 0.0 otherwise)
|
||||
"""
|
||||
# Extract content from different possible formats
|
||||
completion_contents = [
|
||||
self.get_content(completion) for completion in completions
|
||||
]
|
||||
|
||||
# For each completion, check for the preferred tags
|
||||
rewards = []
|
||||
|
||||
flags = 0 if self.case_sensitive else re.IGNORECASE
|
||||
flags |= re.DOTALL # Allow . to match newlines
|
||||
|
||||
for content in completion_contents:
|
||||
if self.require_all_tags:
|
||||
# All tags must be present
|
||||
all_tags_present = True
|
||||
for tag in self.preferred_tags:
|
||||
pattern = f"<{tag}>.*?</{tag}>"
|
||||
if not re.search(pattern, content, flags):
|
||||
all_tags_present = False
|
||||
break
|
||||
rewards.append(1.0 if all_tags_present else 0.0)
|
||||
else:
|
||||
# Any tag can be present
|
||||
has_tags = False
|
||||
for tag in self.preferred_tags:
|
||||
pattern = f"<{tag}>.*?</{tag}>"
|
||||
if re.search(pattern, content, flags):
|
||||
has_tags = True
|
||||
break
|
||||
rewards.append(1.0 if has_tags else 0.0)
|
||||
|
||||
# Log the results
|
||||
logger.info(
|
||||
f"Format reward results: {sum(rewards)}/{len(rewards)} completions match format"
|
||||
)
|
||||
return rewards
|
||||
|
||||
|
||||
# Legacy function for backward compatibility
|
||||
def format_reward(
|
||||
completions: List[Any], preferred_tags: Optional[List[str]] = None, **kwargs
|
||||
) -> List[float]:
|
||||
"""
|
||||
Legacy function wrapper for FormatReward.
|
||||
|
||||
Args:
|
||||
completions: List of completions to evaluate
|
||||
preferred_tags: List of tag names to search for (defaults to ['think', 'answer'])
|
||||
**kwargs: Additional keyword arguments
|
||||
|
||||
Returns:
|
||||
List of rewards for each completion (1.0 for good format, 0.0 otherwise)
|
||||
"""
|
||||
reward_fn = FormatReward(preferred_tags=preferred_tags)
|
||||
return reward_fn.compute(completions, **kwargs)
|
||||
Loading…
Add table
Add a link
Reference in a new issue