atropos/atroposlib/envs/reward_fns/combined_reward.py
2025-04-29 12:10:10 -07:00

91 lines
3.2 KiB
Python

"""Combined reward function that combines multiple reward functions."""
import logging
from typing import Any, Dict, List, Union
from .registry import registry
from .reward_function import RewardFunction
logger = logging.getLogger(__name__)
@registry.register
class CombinedReward(RewardFunction):
    """Meta reward function that combines multiple reward functions.

    Every configured sub-reward is evaluated on the same completions, the
    per-completion results are summed, and the totals are optionally
    normalized according to ``normalization``.
    """

    def __init__(
        self,
        rewards: List[Union[str, Dict]],
        normalization: str = "none",
        weight: float = 1.0,
        **kwargs,
    ):
        """
        Initialize with a list of reward functions to combine.

        Args:
            rewards: List of reward functions (names or config dicts). Each
                entry is handed to ``registry.create`` to build a sub-reward.
            normalization: How to normalize rewards, one of:
                - "none": No normalization
                - "sum": Divide by sum of the sub-rewards' weights
                - "minmax": Scale to range [0,1] based on min/max values
                Any other value behaves like "none"; a warning is logged so
                that a misspelled mode does not go unnoticed.
            weight: Weight for this combined reward
            **kwargs: Additional parameters, forwarded to the base class
        """
        super().__init__(weight=weight, **kwargs)
        if normalization not in ("none", "sum", "minmax"):
            # Kept non-fatal for backward compatibility: unknown modes have
            # always fallen through to "no normalization" in compute().
            logger.warning(
                "Unknown normalization %r; combined rewards will not be "
                "normalized",
                normalization,
            )
        self.normalization = normalization
        # Initialize all sub-reward functions
        self.reward_functions = [
            registry.create(reward_config) for reward_config in rewards
        ]

    @property
    def name(self) -> str:
        """Get a descriptive name for this combined reward.

        Returns:
            ``"combined(<name1>,<name2>,...)"`` built from the sub-rewards'
            names, in the order they were configured.
        """
        return f"combined({','.join(r.name for r in self.reward_functions)})"

    def set_wandb_logger(self, logger):
        """Propagate the WandB logger to this reward and all sub-rewards.

        Args:
            logger: The WandB logger object. Note that inside this method the
                parameter shadows the module-level ``logging`` logger.
        """
        super().set_wandb_logger(logger)
        for reward_fn in self.reward_functions:
            reward_fn.set_wandb_logger(logger)

    def compute(self, completions: List[Any], **kwargs) -> List[float]:
        """Compute combined rewards by calling all sub-rewards.

        Each sub-reward's ``compute`` is called with the same ``completions``
        and ``**kwargs``. A sub-reward that raises, returns a non-iterable,
        returns a different number of values than there are completions, or
        returns values that cannot be added is logged and skipped as a whole;
        it never contributes a partial result to the totals.

        Args:
            completions: Completions to score.
            **kwargs: Extra arguments passed through to every sub-reward.

        Returns:
            One combined reward per completion (empty list when
            ``completions`` is empty), normalized per ``self.normalization``.
        """
        if not completions:
            return []

        expected = len(completions)
        # Running per-completion totals, starting from zero
        combined_rewards = [0.0] * expected

        for reward_fn in self.reward_functions:
            try:
                rewards = list(reward_fn.compute(completions, **kwargs))
                if len(rewards) != expected:
                    raise ValueError(
                        f"expected {expected} rewards, got {len(rewards)}"
                    )
                # Build the new totals on the side so that a failure midway
                # (e.g. a None entry) cannot leave the running totals with
                # only some positions updated.
                updated = [
                    total + value
                    for total, value in zip(combined_rewards, rewards)
                ]
            except Exception as e:
                # Best effort: one failing sub-reward must not prevent the
                # others from being combined.
                logger.error(f"Error computing reward for {reward_fn.name}: {e}")
                logger.exception(e)
                continue
            combined_rewards = updated

        # Apply normalization if needed
        if self.normalization == "sum":
            # NOTE(review): this divides by the sub-rewards' weights, but the
            # weights are not applied to the values summed above — confirm
            # whether sub-reward compute() results are meant to be pre-weighted.
            total_weight = sum(r.weight for r in self.reward_functions)
            if total_weight > 0:
                combined_rewards = [r / total_weight for r in combined_rewards]
        elif self.normalization == "minmax":
            reward_min = min(combined_rewards)
            reward_max = max(combined_rewards)
            # Avoid division by zero: when every total is identical the
            # values are returned unchanged.
            if reward_max > reward_min:
                span = reward_max - reward_min
                combined_rewards = [
                    (r - reward_min) / span for r in combined_rewards
                ]

        return combined_rewards