"""Combined reward function that combines multiple reward functions."""

import logging
from typing import Any, Dict, List, Union

from .registry import registry
from .reward_function import RewardFunction

logger = logging.getLogger(__name__)


@registry.register
class CombinedReward(RewardFunction):
    """Meta reward function that combines multiple reward functions"""

    def __init__(
        self,
        rewards: List[Union[str, Dict]],
        normalization: str = "none",
        weight: float = 1.0,
        **kwargs,
    ):
        """
        Initialize with a list of reward functions to combine.

        Args:
            rewards: List of reward functions (names or config dicts)
            normalization: How to normalize rewards, one of:
                - "none": No normalization
                - "sum": Divide by sum of weights
                - "minmax": Scale to range [0,1] based on min/max values
            weight: Weight for this combined reward
            **kwargs: Additional parameters
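
        For example, with two sub-rewards of weight 1.0 that return
        0.4 and 0.8 for a completion, "none" yields 1.2 while "sum"
        yields 0.6 (a sketch; assumes each sub-reward applies its own
        weight inside compute()).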
        """
        super().__init__(weight=weight, **kwargs)
        self.normalization = normalization
        self.reward_functions = []

        # Initialize all sub-reward functions
        for reward_config in rewards:
            self.reward_functions.append(registry.create(reward_config))

    @property
    def name(self) -> str:
        """Get a descriptive name for this combined reward"""
        return f"combined({','.join(r.name for r in self.reward_functions)})"

    def set_wandb_logger(self, logger):
        """Propagate the WandB logger to all sub-rewards"""
        super().set_wandb_logger(logger)
        for reward_fn in self.reward_functions:
            reward_fn.set_wandb_logger(logger)

    def compute(self, completions: List[Any], **kwargs) -> List[float]:
        """Compute combined rewards by calling all sub-rewards"""
        if not completions:
            return []

        # Initialize with zeros
        combined_rewards = [0.0] * len(completions)

        # Collect all sub-reward values
        all_rewards = []
        for reward_fn in self.reward_functions:
            try:
                rewards = reward_fn.compute(completions, **kwargs)
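                # NOTE: each sub-reward is assumed to apply its own `weight`
                # inside compute(); the "sum" normalization below relies on
                # that to form a weighted average.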
                all_rewards.append(rewards)

                # Add to combined total (pre-normalization)
                for i, r in enumerate(rewards):
                    combined_rewards[i] += r
            except Exception as e:
                # logger.exception logs at error level and includes the
                # traceback, so a single call suffices here
                logger.exception(f"Error computing reward for {reward_fn.name}: {e}")

        # Apply normalization if needed
        if self.normalization == "sum":
            total_weight = sum(r.weight for r in self.reward_functions)
            if total_weight > 0:
                combined_rewards = [r / total_weight for r in combined_rewards]
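                # e.g. weights 1.0 and 3.0 -> divide each summed value by 4.0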
        elif self.normalization == "minmax":
            # Avoid division by zero
            reward_min = min(combined_rewards) if combined_rewards else 0
            reward_max = max(combined_rewards) if combined_rewards else 0
            if reward_max > reward_min:
                combined_rewards = [
                    (r - reward_min) / (reward_max - reward_min)
                    for r in combined_rewards
                ]
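                # e.g. [1.0, 3.0, 2.0] -> [0.0, 1.0, 0.5]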

        return combined_rewards
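

# Usage sketch, assuming reward functions named "accuracy" and "format"
# (hypothetical names) were registered with `registry` elsewhere; config
# dicts are also accepted, with whatever schema `registry.create` expects:
#
#     combined = CombinedReward(rewards=["accuracy", "format"], normalization="sum")
#     scores = combined.compute(["completion one", "completion two"])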