This commit is contained in:
Prakarsh Kaushik 2026-03-30 21:03:59 +00:00 committed by GitHub
commit 337faab79e
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 626 additions and 1 deletions

View file

@ -24,7 +24,8 @@ Usage:
"""
from .combined_reward import CombinedReward
from .ensemble_reward import EnsembleReward
from .registry import registry
from .reward_function import RewardFunction
__all__ = ["RewardFunction", "registry", "CombinedReward"]
__all__ = ["RewardFunction", "registry", "CombinedReward", "EnsembleReward"]

View file

@ -0,0 +1,314 @@
"""
Ensemble reward function with robust aggregation and inter-rater reliability.
Extends the CombinedReward pattern with:
- Multiple aggregation strategies (mean, median, min, majority_vote)
- Inter-rater reliability metrics (Krippendorff's alpha)
- Disagreement tracking for reward hacking detection
Usage:
reward_fn = registry.create("ensemble", rewards=["accuracy", "format"], strategy="median")
scores = reward_fn(completions, **kwargs)
# Access reliability metrics
alpha = reward_fn.last_reliability_alpha
"""
import logging
import warnings
from typing import Any, Dict, List, Optional, Union
import numpy as np
from .registry import registry
from .reward_function import RewardFunction
logger = logging.getLogger(__name__)
def _krippendorff_alpha(ratings_matrix: np.ndarray) -> float:
    """
    Compute Krippendorff's alpha for inter-rater reliability.

    Uses the interval/ratio metric (squared differences).

    Args:
        ratings_matrix: Shape (n_raters, n_items). NaN values indicate
            missing ratings and are excluded from computation.

    Returns:
        Alpha value in [-1, 1]. 1 = perfect agreement, 0 = chance agreement,
        negative = systematic disagreement.
    """
    n_raters, n_items = ratings_matrix.shape
    if n_raters < 2 or n_items < 2:
        return float("nan")
    # Observed disagreement: mean squared difference over all rater pairs
    # within each item.  Pairwise differences are computed with numpy
    # broadcasting instead of a nested Python loop; summing the full
    # (m, m) difference matrix counts each unordered pair twice, so the
    # total is halved.  Identical values subtract to exactly 0.0, which
    # keeps the perfect-agreement fast path below intact.
    observed_sum = 0.0
    total_pairs = 0
    for item_idx in range(n_items):
        values = ratings_matrix[:, item_idx]
        valid = values[~np.isnan(values)]
        n_valid = valid.size
        if n_valid < 2:
            continue
        diffs = valid[:, None] - valid[None, :]
        observed_sum += float(np.sum(diffs * diffs)) / 2.0
        total_pairs += n_valid * (n_valid - 1) // 2
    if total_pairs == 0:
        return float("nan")
    observed_disagreement = observed_sum / total_pairs
    # Expected disagreement: mean squared difference over ALL pairs of
    # valid values regardless of item.  The broadcast builds an (n, n)
    # matrix -- O(n^2) memory -- but replaces the previous pure-Python
    # double loop, which was quadratic in time at interpreter speed.
    all_valid = ratings_matrix[~np.isnan(ratings_matrix)]
    n_all = all_valid.size
    if n_all < 2:
        return float("nan")
    all_diffs = all_valid[:, None] - all_valid[None, :]
    expected_sum = float(np.sum(all_diffs * all_diffs)) / 2.0
    expected_pairs = n_all * (n_all - 1) // 2
    expected_disagreement = expected_sum / expected_pairs
    if expected_disagreement == 0.0:
        # All raters gave identical scores -- perfect agreement
        return 1.0
    return float(1.0 - observed_disagreement / expected_disagreement)
@registry.register
class EnsembleReward(RewardFunction):
    """
    Ensemble reward function that aggregates multiple reward functions
    with robust strategies and inter-rater reliability tracking.

    Compared to CombinedReward, this adds:
    - Median and min (conservative) aggregation for robustness
    - Majority vote for binary reward environments
    - Krippendorff's alpha inter-rater reliability metric
    - Per-item disagreement tracking for reward hacking detection

    Strategies:
    - "mean": Weighted average (same as CombinedReward)
    - "median": Median across reward functions (robust to outliers)
    - "min": Conservative -- use the minimum score (prevents reward hacking)
    - "majority_vote": For binary rewards -- majority wins (ties -> positive)
    """

    def __init__(
        self,
        rewards: List[Union[str, Dict]],
        strategy: str = "mean",
        weight: float = 1.0,
        track_disagreement: bool = True,
        **kwargs,
    ):
        """
        Initialize the ensemble reward function.

        Args:
            rewards: List of reward function names or config dicts.
                Resolved via RewardRegistry.
            strategy: Aggregation strategy. One of: "mean", "median",
                "min", "majority_vote".
            weight: Weight for this ensemble when used inside another
                CombinedReward.
            track_disagreement: If True, track per-item reward variance
                for disagreement analysis.
            **kwargs: Additional parameters passed to RewardFunction.
        """
        super().__init__(weight=weight, **kwargs)
        valid_strategies = {"mean", "median", "min", "majority_vote"}
        if strategy not in valid_strategies:
            raise ValueError(
                f"Invalid strategy '{strategy}'. Must be one of: {valid_strategies}"
            )
        self.strategy = strategy
        self.track_disagreement = track_disagreement
        # Each entry of ``rewards`` (a name or a config dict) is resolved
        # through the global registry into a concrete RewardFunction.
        self.reward_functions: List[RewardFunction] = [
            registry.create(reward_config) for reward_config in rewards
        ]
        if len(self.reward_functions) < 2:
            warnings.warn(
                "EnsembleReward initialized with fewer than 2 reward functions. "
                "Inter-rater reliability metrics will not be meaningful.",
                stacklevel=2,
            )
        # Reliability state, refreshed on every compute() call.
        self.last_reliability_alpha: float = float("nan")
        self.last_disagreement_scores: Optional[List[float]] = None
        self._all_sub_rewards: Optional[List[List[float]]] = None

    @property
    def name(self) -> str:
        inner = ",".join(fn.name for fn in self.reward_functions)
        return f"ensemble_{self.strategy}({inner})"

    def set_wandb_logger(self, wandb_logger):
        """Propagate WandB logger to all sub-reward functions."""
        super().set_wandb_logger(wandb_logger)
        for fn in self.reward_functions:
            fn.set_wandb_logger(wandb_logger)

    def compute(self, completions: List[Any], **kwargs) -> List[float]:
        """
        Compute ensemble reward scores.

        Calls all sub-reward functions, aggregates by strategy,
        and computes reliability metrics.

        Args:
            completions: List of completions to evaluate.
            **kwargs: Additional context passed to sub-rewards.

        Returns:
            Aggregated reward scores, one per completion.
        """
        if not completions:
            return []
        n_completions = len(completions)
        collected: List[List[float]] = [
            self._safe_scores(fn, completions, n_completions, **kwargs)
            for fn in self.reward_functions
        ]
        self._all_sub_rewards = collected
        if not collected:
            return [0.0] * n_completions
        # Shape: (n_reward_fns, n_completions)
        matrix = np.array(collected, dtype=np.float64)
        aggregated = self._aggregate(matrix)
        self._compute_reliability_metrics(matrix)
        if self.track_disagreement:
            # Per-item variance across sub-rewards = disagreement signal.
            self.last_disagreement_scores = np.var(matrix, axis=0).tolist()
        return aggregated.tolist()

    def _safe_scores(
        self, fn: RewardFunction, completions: List[Any], n_completions: int, **kwargs
    ) -> List[float]:
        """Run one sub-reward, coercing the result to exactly n_completions floats."""
        try:
            scores = fn.compute(completions, **kwargs)
            if len(scores) != n_completions:
                logger.warning(
                    "Reward function %s returned %d scores for %d completions. "
                    "Padding/truncating.",
                    fn.name,
                    len(scores),
                    n_completions,
                )
                if len(scores) < n_completions:
                    scores = scores + [0.0] * (n_completions - len(scores))
                else:
                    scores = scores[:n_completions]
            return scores
        except Exception as e:
            # A failing sub-reward contributes zeros rather than crashing
            # the whole ensemble.
            logger.error("Error in reward function %s: %s", fn.name, e)
            return [0.0] * n_completions

    def _aggregate(self, matrix: np.ndarray) -> np.ndarray:
        """Collapse the (n_reward_fns, n_completions) matrix along axis 0."""
        if self.strategy == "median":
            return np.median(matrix, axis=0)
        if self.strategy == "min":
            return np.min(matrix, axis=0)
        if self.strategy == "majority_vote":
            # A positive score counts as a vote for 1; ties (exactly half
            # positive) resolve to the positive outcome.
            fraction_positive = np.mean((matrix > 0).astype(np.float64), axis=0)
            return np.where(fraction_positive >= 0.5, 1.0, 0.0)
        # "mean" -- also the (unreachable) fallback, since __init__ validates.
        return np.mean(matrix, axis=0)

    def _compute_reliability_metrics(self, reward_matrix: np.ndarray):
        """
        Compute and store inter-rater reliability metrics.

        Args:
            reward_matrix: Shape (n_raters, n_items)
        """
        n_raters, n_items = reward_matrix.shape
        if n_raters >= 2 and n_items >= 2:
            self.last_reliability_alpha = _krippendorff_alpha(reward_matrix)
        else:
            # Alpha is undefined with a single rater or a single item.
            self.last_reliability_alpha = float("nan")

    def reliability_metrics(self) -> Dict[str, float]:
        """
        Return the latest inter-rater reliability metrics.

        Returns:
            Dictionary with reliability statistics:
            - alpha: Krippendorff's alpha
            - mean_disagreement: Average per-item variance across raters
            - max_disagreement: Maximum per-item variance (worst agreement)
        """
        stats = {"alpha": self.last_reliability_alpha}
        scores = self.last_disagreement_scores
        if scores is not None:
            stats["mean_disagreement"] = sum(scores) / len(scores) if scores else 0.0
            stats["max_disagreement"] = max(scores) if scores else 0.0
        return stats

    def log_metrics(self, raw_rewards: List[float], weighted_rewards: List[float]):
        """Log ensemble-specific metrics alongside standard reward metrics."""
        super().log_metrics(raw_rewards, weighted_rewards)
        if not self.wandb_logger:
            return
        stats = self.reliability_metrics()
        payload = {}
        alpha = stats.get("alpha", float("nan"))
        if not np.isnan(alpha):
            payload[f"reward/{self.name}/reliability_alpha"] = alpha
        if "mean_disagreement" in stats:
            payload[f"reward/{self.name}/mean_disagreement"] = stats[
                "mean_disagreement"
            ]
            payload[f"reward/{self.name}/max_disagreement"] = stats[
                "max_disagreement"
            ]
        if payload:
            self.wandb_logger.log(payload)

View file

@ -0,0 +1,310 @@
"""
Tests for EnsembleReward -- reward aggregation with inter-rater reliability.
Tests cover:
- All aggregation strategies (mean, median, min, majority_vote)
- Krippendorff's alpha computation (perfect/no agreement)
- Disagreement tracking
- Registry integration
- Edge cases (empty completions, single reward function)
"""
import math
from typing import Any, List
import numpy as np
import pytest
from atroposlib.envs.reward_fns.ensemble_reward import (
EnsembleReward,
_krippendorff_alpha,
)
from atroposlib.envs.reward_fns.registry import RewardRegistry
from atroposlib.envs.reward_fns.reward_function import RewardFunction
# ---------------------------------------------------------------------------
# Test fixtures -- simple reward functions for composing ensembles
# ---------------------------------------------------------------------------
class ConstantReward(RewardFunction):
    """Returns a fixed score for every completion."""

    def __init__(self, value: float = 1.0, **kwargs):
        super().__init__(**kwargs)
        self._value = value

    def compute(self, completions: List[Any], **kwargs) -> List[float]:
        # Same fixed score regardless of completion content.
        return [self._value for _ in completions]
class LengthReward(RewardFunction):
    """Scores by string length (for testing divergent reward signals)."""

    def compute(self, completions: List[Any], **kwargs) -> List[float]:
        scores = []
        for completion in completions:
            scores.append(float(len(self.get_content(completion))))
        return scores
class BinaryReward(RewardFunction):
    """Returns 1.0 if completion contains 'good', else 0.0."""

    def compute(self, completions: List[Any], **kwargs) -> List[float]:
        scores = []
        for completion in completions:
            text = self.get_content(completion).lower()
            scores.append(1.0 if "good" in text else 0.0)
        return scores
def _make_ensemble(strategy, reward_functions):
    """Helper to construct an EnsembleReward without going through registry."""
    # __new__ skips __init__, so every attribute that __init__ (and the
    # RewardFunction base) would normally set is assigned explicitly here.
    ensemble = EnsembleReward.__new__(EnsembleReward)
    initial_state = {
        "weight": 1.0,
        "strategy": strategy,
        "track_disagreement": True,
        "reward_functions": reward_functions,
        "wandb_logger": None,
        "_name": None,
        "config": {},
        "last_reliability_alpha": float("nan"),
        "last_disagreement_scores": None,
        "_all_sub_rewards": None,
    }
    for attr, value in initial_state.items():
        setattr(ensemble, attr, value)
    return ensemble
@pytest.fixture
def test_registry():
    """Create a clean registry with test reward functions."""
    reg = RewardRegistry()
    for reward_name, reward_cls in [
        ("constant", ConstantReward),
        ("length", LengthReward),
        ("binary", BinaryReward),
    ]:
        reg.register(name=reward_name)(reward_cls)
    return reg
@pytest.fixture
def completions():
    """Sample completions for testing."""
    samples = [
        "short",
        "a medium length string",
        "good answer here",
    ]
    return samples
# ---------------------------------------------------------------------------
# Aggregation strategy tests
# ---------------------------------------------------------------------------
class TestMeanAggregation:
    def test_mean_of_identical_scores(self, completions):
        raters = [ConstantReward(value=2.0), ConstantReward(value=2.0)]
        result = _make_ensemble("mean", raters).compute(completions)
        assert len(result) == 3
        assert all(math.isclose(score, 2.0, rel_tol=1e-9) for score in result)

    def test_mean_of_different_scores(self, completions):
        raters = [ConstantReward(value=1.0), ConstantReward(value=3.0)]
        result = _make_ensemble("mean", raters).compute(completions)
        # (1.0 + 3.0) / 2 == 2.0 for every completion.
        assert all(math.isclose(score, 2.0, rel_tol=1e-9) for score in result)
class TestMedianAggregation:
    def test_median_rejects_outlier(self, completions):
        """Median should be robust to a single outlier reward function."""
        raters = [
            ConstantReward(value=1.0),
            ConstantReward(value=1.0),
            ConstantReward(value=100.0),
        ]
        result = _make_ensemble("median", raters).compute(completions)
        # The 100.0 outlier must not move the median off 1.0.
        assert all(math.isclose(score, 1.0, rel_tol=1e-9) for score in result)
class TestMinAggregation:
    def test_min_is_conservative(self, completions):
        raters = [
            ConstantReward(value=0.5),
            ConstantReward(value=0.8),
            ConstantReward(value=1.0),
        ]
        result = _make_ensemble("min", raters).compute(completions)
        # The smallest sub-reward (0.5) wins everywhere.
        assert all(math.isclose(score, 0.5, rel_tol=1e-9) for score in result)
class TestMajorityVoteAggregation:
    def test_majority_positive(self, completions):
        voters = [
            ConstantReward(value=1.0),
            ConstantReward(value=1.0),
            ConstantReward(value=-1.0),
        ]
        result = _make_ensemble("majority_vote", voters).compute(completions)
        # Two of three votes positive -> 1.0.
        assert all(math.isclose(score, 1.0) for score in result)

    def test_majority_negative(self, completions):
        voters = [
            ConstantReward(value=-1.0),
            ConstantReward(value=-1.0),
            ConstantReward(value=1.0),
        ]
        result = _make_ensemble("majority_vote", voters).compute(completions)
        # Two of three votes non-positive -> 0.0.
        assert all(math.isclose(score, 0.0) for score in result)

    def test_tie_goes_positive(self, completions):
        voters = [ConstantReward(value=1.0), ConstantReward(value=-1.0)]
        result = _make_ensemble("majority_vote", voters).compute(completions)
        # A 50/50 split resolves in favor of the positive outcome.
        assert all(math.isclose(score, 1.0) for score in result)
# ---------------------------------------------------------------------------
# Inter-rater reliability tests
# ---------------------------------------------------------------------------
class TestKrippendorffAlpha:
    def test_perfect_agreement(self):
        row = [1.0, 2.0, 3.0, 4.0]
        ratings = np.array([row, row, row])
        assert math.isclose(_krippendorff_alpha(ratings), 1.0, rel_tol=1e-9)

    def test_no_agreement(self):
        # Two raters that always disagree on binary items.
        ratings = np.array([[1.0, 0.0, 1.0, 0.0], [0.0, 1.0, 0.0, 1.0]])
        assert _krippendorff_alpha(ratings) < 0.0

    def test_random_agreement(self):
        np.random.seed(42)
        # Independent random ratings should land near chance agreement.
        assert abs(_krippendorff_alpha(np.random.rand(5, 100))) < 0.3

    def test_insufficient_data(self):
        single_rater = np.array([[1.0, 2.0, 3.0]])
        single_item = np.array([[1.0], [2.0]])
        assert math.isnan(_krippendorff_alpha(single_rater))
        assert math.isnan(_krippendorff_alpha(single_item))
class TestReliabilityMetrics:
    def test_reliability_computed_after_scoring(self, completions):
        raters = [ConstantReward(value=1.0), ConstantReward(value=1.0)]
        ensemble = _make_ensemble("mean", raters)
        ensemble.compute(completions)
        metrics = ensemble.reliability_metrics()
        assert "alpha" in metrics
        # Identical raters -> perfect agreement.
        assert math.isclose(metrics["alpha"], 1.0, rel_tol=1e-9)

    def test_disagreement_tracked(self, completions):
        raters = [ConstantReward(value=0.0), ConstantReward(value=10.0)]
        ensemble = _make_ensemble("mean", raters)
        ensemble.compute(completions)
        tracked = ensemble.last_disagreement_scores
        assert tracked is not None
        assert len(tracked) == len(completions)
        # Variance of [0.0, 10.0] = 25.0
        assert all(math.isclose(item, 25.0, rel_tol=1e-9) for item in tracked)
# ---------------------------------------------------------------------------
# Registry integration
# ---------------------------------------------------------------------------
class TestRegistryIntegration:
    def test_create_via_registry(self, test_registry):
        # EnsembleReward.__init__ resolves sub-rewards via the global registry,
        # so we must register our test fixtures there too.
        from atroposlib.envs.reward_fns.registry import registry as global_registry

        global_registry.register(name="test_constant")(ConstantReward)
        test_registry.register(name="ensemble")(EnsembleReward)
        config = {
            "type": "ensemble",
            "rewards": ["test_constant", "test_constant"],
            "strategy": "median",
        }
        ensemble = test_registry.create(config)
        assert isinstance(ensemble, EnsembleReward)
        assert ensemble.strategy == "median"
        assert len(ensemble.reward_functions) == 2
# ---------------------------------------------------------------------------
# Edge cases
# ---------------------------------------------------------------------------
class TestEdgeCases:
    def test_empty_completions(self):
        ensemble = _make_ensemble("mean", [ConstantReward(value=1.0)])
        assert ensemble.compute([]) == []

    def test_invalid_strategy_raises(self):
        with pytest.raises(ValueError, match="Invalid strategy"):
            EnsembleReward(rewards=[], strategy="nonexistent")

    def test_name_format(self):
        members = [ConstantReward(value=1.0), LengthReward()]
        label = _make_ensemble("median", members).name
        assert "ensemble_median" in label
        assert "constantreward" in label
        assert "lengthreward" in label