mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-19 12:57:58 +00:00
Merge f68ae5e748 into c20c85256e
This commit is contained in:
commit
337faab79e
3 changed files with 626 additions and 1 deletions
|
|
@ -24,7 +24,8 @@ Usage:
|
|||
"""
|
||||
|
||||
from .combined_reward import CombinedReward
|
||||
from .ensemble_reward import EnsembleReward
|
||||
from .registry import registry
|
||||
from .reward_function import RewardFunction
|
||||
|
||||
__all__ = ["RewardFunction", "registry", "CombinedReward"]
|
||||
__all__ = ["RewardFunction", "registry", "CombinedReward", "EnsembleReward"]
|
||||
|
|
|
|||
314
atroposlib/envs/reward_fns/ensemble_reward.py
Normal file
314
atroposlib/envs/reward_fns/ensemble_reward.py
Normal file
|
|
@ -0,0 +1,314 @@
|
|||
"""
|
||||
Ensemble reward function with robust aggregation and inter-rater reliability.
|
||||
|
||||
Extends the CombinedReward pattern with:
|
||||
- Multiple aggregation strategies (mean, median, min, majority_vote)
|
||||
- Inter-rater reliability metrics (Krippendorff's alpha)
|
||||
- Disagreement tracking for reward hacking detection
|
||||
|
||||
Usage:
|
||||
reward_fn = registry.create("ensemble", rewards=["accuracy", "format"], strategy="median")
|
||||
scores = reward_fn(completions, **kwargs)
|
||||
|
||||
# Access reliability metrics
|
||||
alpha = reward_fn.last_reliability_alpha
|
||||
"""
|
||||
|
||||
import logging
|
||||
import warnings
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
from .registry import registry
|
||||
from .reward_function import RewardFunction
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _krippendorff_alpha(ratings_matrix: np.ndarray) -> float:
|
||||
"""
|
||||
Compute Krippendorff's alpha for inter-rater reliability.
|
||||
|
||||
Uses the interval/ratio metric (squared differences).
|
||||
|
||||
Args:
|
||||
ratings_matrix: Shape (n_raters, n_items). NaN values indicate
|
||||
missing ratings and are excluded from computation.
|
||||
|
||||
Returns:
|
||||
Alpha value in [-1, 1]. 1 = perfect agreement, 0 = chance agreement,
|
||||
negative = systematic disagreement.
|
||||
"""
|
||||
n_raters, n_items = ratings_matrix.shape
|
||||
|
||||
if n_raters < 2 or n_items < 2:
|
||||
return float("nan")
|
||||
|
||||
# Build coincidence matrix approach using pairwise disagreements
|
||||
# For each item, compute observed disagreement across all rater pairs
|
||||
observed_disagreement = 0.0
|
||||
total_pairs = 0
|
||||
|
||||
for item_idx in range(n_items):
|
||||
values = ratings_matrix[:, item_idx]
|
||||
valid = values[~np.isnan(values)]
|
||||
n_valid = len(valid)
|
||||
if n_valid < 2:
|
||||
continue
|
||||
|
||||
# Sum of squared differences for all pairs within this item
|
||||
for i in range(n_valid):
|
||||
for j in range(i + 1, n_valid):
|
||||
observed_disagreement += (valid[i] - valid[j]) ** 2
|
||||
total_pairs += 1
|
||||
|
||||
if total_pairs == 0:
|
||||
return float("nan")
|
||||
|
||||
observed_disagreement /= total_pairs
|
||||
|
||||
# Expected disagreement: pairwise differences across ALL values
|
||||
all_valid = ratings_matrix[~np.isnan(ratings_matrix)]
|
||||
n_all = len(all_valid)
|
||||
if n_all < 2:
|
||||
return float("nan")
|
||||
|
||||
expected_disagreement = 0.0
|
||||
expected_pairs = 0
|
||||
for i in range(n_all):
|
||||
for j in range(i + 1, n_all):
|
||||
expected_disagreement += (all_valid[i] - all_valid[j]) ** 2
|
||||
expected_pairs += 1
|
||||
|
||||
if expected_pairs == 0:
|
||||
return float("nan")
|
||||
|
||||
expected_disagreement /= expected_pairs
|
||||
|
||||
if expected_disagreement == 0.0:
|
||||
# All raters gave identical scores -- perfect agreement
|
||||
return 1.0
|
||||
|
||||
alpha = 1.0 - (observed_disagreement / expected_disagreement)
|
||||
return float(alpha)
|
||||
|
||||
|
||||
@registry.register
class EnsembleReward(RewardFunction):
    """
    Ensemble reward function that aggregates multiple reward functions
    with robust strategies and inter-rater reliability tracking.

    Compared to CombinedReward, this adds:
    - Median and min (conservative) aggregation for robustness
    - Majority vote for binary reward environments
    - Krippendorff's alpha inter-rater reliability metric
    - Per-item disagreement tracking for reward hacking detection

    Strategies:
    - "mean": Weighted average (same as CombinedReward)
    - "median": Median across reward functions (robust to outliers)
    - "min": Conservative -- use the minimum score (prevents reward hacking)
    - "majority_vote": For binary rewards -- majority wins (ties -> positive)
    """

    def __init__(
        self,
        rewards: List[Union[str, Dict]],
        strategy: str = "mean",
        weight: float = 1.0,
        track_disagreement: bool = True,
        **kwargs,
    ):
        """
        Initialize the ensemble reward function.

        Args:
            rewards: List of reward function names or config dicts.
                Resolved via RewardRegistry.
            strategy: Aggregation strategy. One of: "mean", "median",
                "min", "majority_vote".
            weight: Weight for this ensemble when used inside another
                CombinedReward.
            track_disagreement: If True, track per-item reward variance
                for disagreement analysis.
            **kwargs: Additional parameters passed to RewardFunction.

        Raises:
            ValueError: If ``strategy`` is not one of the supported names.
        """
        super().__init__(weight=weight, **kwargs)

        valid_strategies = {"mean", "median", "min", "majority_vote"}
        if strategy not in valid_strategies:
            raise ValueError(
                f"Invalid strategy '{strategy}'. Must be one of: {valid_strategies}"
            )

        self.strategy = strategy
        self.track_disagreement = track_disagreement
        self.reward_functions: List[RewardFunction] = []

        # Initialize sub-reward functions via registry
        # NOTE(review): resolution goes through the module-global registry,
        # not an injected one -- sub-reward names must be registered there.
        for reward_config in rewards:
            self.reward_functions.append(registry.create(reward_config))

        if len(self.reward_functions) < 2:
            warnings.warn(
                "EnsembleReward initialized with fewer than 2 reward functions. "
                "Inter-rater reliability metrics will not be meaningful.",
                stacklevel=2,
            )

        # State for reliability tracking -- refreshed on every compute() call.
        self.last_reliability_alpha: float = float("nan")
        self.last_disagreement_scores: Optional[List[float]] = None
        self._all_sub_rewards: Optional[List[List[float]]] = None

    @property
    def name(self) -> str:
        # e.g. "ensemble_median(accuracy,format)" -- embeds strategy and members.
        sub_names = ",".join(r.name for r in self.reward_functions)
        return f"ensemble_{self.strategy}({sub_names})"

    def set_wandb_logger(self, wandb_logger):
        """Propagate WandB logger to all sub-reward functions."""
        super().set_wandb_logger(wandb_logger)
        for reward_fn in self.reward_functions:
            reward_fn.set_wandb_logger(wandb_logger)

    def compute(self, completions: List[Any], **kwargs) -> List[float]:
        """
        Compute ensemble reward scores.

        Calls all sub-reward functions, aggregates by strategy,
        and computes reliability metrics.

        Args:
            completions: List of completions to evaluate.
            **kwargs: Additional context passed to sub-rewards.

        Returns:
            Aggregated reward scores, one per completion.
        """
        if not completions:
            return []

        n_completions = len(completions)

        # Collect all sub-reward scores; shape-normalize each row so the
        # matrix below is rectangular.
        all_rewards: List[List[float]] = []
        for reward_fn in self.reward_functions:
            try:
                scores = reward_fn.compute(completions, **kwargs)
                if len(scores) != n_completions:
                    logger.warning(
                        "Reward function %s returned %d scores for %d completions. "
                        "Padding/truncating.",
                        reward_fn.name,
                        len(scores),
                        n_completions,
                    )
                    # Pad or truncate
                    if len(scores) < n_completions:
                        scores = scores + [0.0] * (n_completions - len(scores))
                    else:
                        scores = scores[:n_completions]
                all_rewards.append(scores)
            except Exception as e:
                # NOTE(review): a failed sub-reward contributes all zeros; under
                # "majority_vote" that counts as negative votes (0 is non-positive)
                # and under "min" it can floor the ensemble score at <= 0.
                logger.error("Error in reward function %s: %s", reward_fn.name, e)
                all_rewards.append([0.0] * n_completions)

        self._all_sub_rewards = all_rewards

        if not all_rewards:
            return [0.0] * n_completions

        # Convert to numpy for efficient aggregation
        # Shape: (n_reward_fns, n_completions)
        reward_matrix = np.array(all_rewards, dtype=np.float64)

        # Aggregate by strategy
        if self.strategy == "mean":
            aggregated = np.mean(reward_matrix, axis=0)
        elif self.strategy == "median":
            aggregated = np.median(reward_matrix, axis=0)
        elif self.strategy == "min":
            aggregated = np.min(reward_matrix, axis=0)
        elif self.strategy == "majority_vote":
            # Treat positive as vote for 1, non-positive as vote for 0
            votes = (reward_matrix > 0).astype(np.float64)
            vote_fractions = np.mean(votes, axis=0)
            # Majority wins; ties (0.5) go to positive
            aggregated = np.where(vote_fractions >= 0.5, 1.0, 0.0)
        else:
            # Should not reach here due to __init__ validation
            aggregated = np.mean(reward_matrix, axis=0)

        # Compute reliability metrics
        self._compute_reliability_metrics(reward_matrix)

        # Track per-item disagreement (variance across reward functions)
        if self.track_disagreement:
            self.last_disagreement_scores = np.var(reward_matrix, axis=0).tolist()

        return aggregated.tolist()

    def _compute_reliability_metrics(self, reward_matrix: np.ndarray):
        """
        Compute and store inter-rater reliability metrics.

        Stores the result in ``self.last_reliability_alpha`` (NaN when
        fewer than 2 raters or 2 items are available).

        Args:
            reward_matrix: Shape (n_raters, n_items)
        """
        n_raters, n_items = reward_matrix.shape

        if n_raters < 2 or n_items < 2:
            self.last_reliability_alpha = float("nan")
            return

        self.last_reliability_alpha = _krippendorff_alpha(reward_matrix)

    def reliability_metrics(self) -> Dict[str, float]:
        """
        Return the latest inter-rater reliability metrics.

        Returns:
            Dictionary with reliability statistics:
            - alpha: Krippendorff's alpha
            - mean_disagreement: Average per-item variance across raters
            - max_disagreement: Maximum per-item variance (worst agreement)

            The disagreement keys are present only after a compute() call
            with ``track_disagreement`` enabled.
        """
        metrics = {
            "alpha": self.last_reliability_alpha,
        }

        if self.last_disagreement_scores is not None:
            scores = self.last_disagreement_scores
            metrics["mean_disagreement"] = sum(scores) / len(scores) if scores else 0.0
            metrics["max_disagreement"] = max(scores) if scores else 0.0

        return metrics

    def log_metrics(self, raw_rewards: List[float], weighted_rewards: List[float]):
        """Log ensemble-specific metrics alongside standard reward metrics."""
        super().log_metrics(raw_rewards, weighted_rewards)

        if not self.wandb_logger:
            return

        reliability = self.reliability_metrics()
        wandb_metrics = {}

        # Skip alpha when it is NaN (insufficient raters/items).
        if not np.isnan(reliability.get("alpha", float("nan"))):
            wandb_metrics[f"reward/{self.name}/reliability_alpha"] = reliability[
                "alpha"
            ]

        if "mean_disagreement" in reliability:
            wandb_metrics[f"reward/{self.name}/mean_disagreement"] = reliability[
                "mean_disagreement"
            ]
            wandb_metrics[f"reward/{self.name}/max_disagreement"] = reliability[
                "max_disagreement"
            ]

        if wandb_metrics:
            self.wandb_logger.log(wandb_metrics)
|
||||
310
atroposlib/tests/test_reward_ensemble.py
Normal file
310
atroposlib/tests/test_reward_ensemble.py
Normal file
|
|
@ -0,0 +1,310 @@
|
|||
"""
|
||||
Tests for EnsembleReward -- reward aggregation with inter-rater reliability.
|
||||
|
||||
Tests cover:
|
||||
- All aggregation strategies (mean, median, min, majority_vote)
|
||||
- Krippendorff's alpha computation (perfect/no agreement)
|
||||
- Disagreement tracking
|
||||
- Registry integration
|
||||
- Edge cases (empty completions, single reward function)
|
||||
"""
|
||||
|
||||
import math
|
||||
from typing import Any, List
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from atroposlib.envs.reward_fns.ensemble_reward import (
|
||||
EnsembleReward,
|
||||
_krippendorff_alpha,
|
||||
)
|
||||
from atroposlib.envs.reward_fns.registry import RewardRegistry
|
||||
from atroposlib.envs.reward_fns.reward_function import RewardFunction
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Test fixtures -- simple reward functions for composing ensembles
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class ConstantReward(RewardFunction):
    """Reward function that assigns the same fixed score to all completions."""

    def __init__(self, value: float = 1.0, **kwargs):
        super().__init__(**kwargs)
        self._value = value

    def compute(self, completions: List[Any], **kwargs) -> List[float]:
        fixed_score = self._value
        return [fixed_score for _ in completions]
|
||||
|
||||
|
||||
class LengthReward(RewardFunction):
    """Scores each completion by the character count of its extracted content."""

    def compute(self, completions: List[Any], **kwargs) -> List[float]:
        contents = map(self.get_content, completions)
        return [float(len(text)) for text in contents]
|
||||
|
||||
|
||||
class BinaryReward(RewardFunction):
    """Scores 1.0 when the completion text contains 'good', otherwise 0.0."""

    def compute(self, completions: List[Any], **kwargs) -> List[float]:
        scores = []
        for completion in completions:
            text = self.get_content(completion).lower()
            scores.append(1.0 if "good" in text else 0.0)
        return scores
|
||||
|
||||
|
||||
def _make_ensemble(strategy, reward_functions):
    """Helper to construct an EnsembleReward without going through registry."""
    # Bypass __init__ (which would resolve sub-rewards via the global
    # registry) and set the expected attributes directly.
    ensemble = EnsembleReward.__new__(EnsembleReward)
    initial_state = {
        "weight": 1.0,
        "strategy": strategy,
        "track_disagreement": True,
        "reward_functions": reward_functions,
        "wandb_logger": None,
        "_name": None,
        "config": {},
        "last_reliability_alpha": float("nan"),
        "last_disagreement_scores": None,
        "_all_sub_rewards": None,
    }
    for attr, value in initial_state.items():
        setattr(ensemble, attr, value)
    return ensemble
|
||||
|
||||
|
||||
@pytest.fixture
def test_registry():
    """Create a clean registry with test reward functions."""
    reg = RewardRegistry()
    fixtures = (
        ("constant", ConstantReward),
        ("length", LengthReward),
        ("binary", BinaryReward),
    )
    for reward_name, reward_cls in fixtures:
        reg.register(name=reward_name)(reward_cls)
    return reg
|
||||
|
||||
|
||||
@pytest.fixture
def completions():
    """Sample completions: three strings of varying length; only the last contains 'good'."""
    return ["short", "a medium length string", "good answer here"]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Aggregation strategy tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestMeanAggregation:
    def test_mean_of_identical_scores(self, completions):
        raters = [ConstantReward(value=2.0), ConstantReward(value=2.0)]
        ensemble = _make_ensemble("mean", raters)
        scores = ensemble.compute(completions)
        assert len(scores) == 3
        for score in scores:
            assert math.isclose(score, 2.0, rel_tol=1e-9)

    def test_mean_of_different_scores(self, completions):
        raters = [ConstantReward(value=1.0), ConstantReward(value=3.0)]
        ensemble = _make_ensemble("mean", raters)
        # mean(1.0, 3.0) == 2.0 for every completion
        for score in ensemble.compute(completions):
            assert math.isclose(score, 2.0, rel_tol=1e-9)
|
||||
|
||||
|
||||
class TestMedianAggregation:
    def test_median_rejects_outlier(self, completions):
        """Median should be robust to a single outlier reward function."""
        raters = [
            ConstantReward(value=1.0),
            ConstantReward(value=1.0),
            ConstantReward(value=100.0),
        ]
        ensemble = _make_ensemble("median", raters)
        # The 100.0 outlier must not move the median off 1.0.
        for score in ensemble.compute(completions):
            assert math.isclose(score, 1.0, rel_tol=1e-9)
|
||||
|
||||
|
||||
class TestMinAggregation:
    def test_min_is_conservative(self, completions):
        raters = [
            ConstantReward(value=0.5),
            ConstantReward(value=0.8),
            ConstantReward(value=1.0),
        ]
        ensemble = _make_ensemble("min", raters)
        # Conservative aggregation: the lowest rater (0.5) wins everywhere.
        for score in ensemble.compute(completions):
            assert math.isclose(score, 0.5, rel_tol=1e-9)
|
||||
|
||||
|
||||
class TestMajorityVoteAggregation:
    def test_majority_positive(self, completions):
        # Two positive voters against one negative -> 1.0.
        voters = [
            ConstantReward(value=1.0),
            ConstantReward(value=1.0),
            ConstantReward(value=-1.0),
        ]
        ensemble = _make_ensemble("majority_vote", voters)
        for score in ensemble.compute(completions):
            assert math.isclose(score, 1.0)

    def test_majority_negative(self, completions):
        # Two negative voters against one positive -> 0.0.
        voters = [
            ConstantReward(value=-1.0),
            ConstantReward(value=-1.0),
            ConstantReward(value=1.0),
        ]
        ensemble = _make_ensemble("majority_vote", voters)
        for score in ensemble.compute(completions):
            assert math.isclose(score, 0.0)

    def test_tie_goes_positive(self, completions):
        # A 1-1 split is resolved in favor of the positive outcome.
        voters = [ConstantReward(value=1.0), ConstantReward(value=-1.0)]
        ensemble = _make_ensemble("majority_vote", voters)
        for score in ensemble.compute(completions):
            assert math.isclose(score, 1.0)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Inter-rater reliability tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestKrippendorffAlpha:
    def test_perfect_agreement(self):
        # Three raters giving identical ratings -> alpha == 1.0.
        row = [1.0, 2.0, 3.0, 4.0]
        ratings = np.array([row, row, row])
        assert math.isclose(_krippendorff_alpha(ratings), 1.0, rel_tol=1e-9)

    def test_no_agreement(self):
        # Perfectly anti-correlated raters -> negative alpha.
        ratings = np.array(
            [
                [1.0, 0.0, 1.0, 0.0],
                [0.0, 1.0, 0.0, 1.0],
            ]
        )
        assert _krippendorff_alpha(ratings) < 0.0

    def test_random_agreement(self):
        # Independent random raters should land near chance agreement.
        np.random.seed(42)
        alpha = _krippendorff_alpha(np.random.rand(5, 100))
        assert abs(alpha) < 0.3

    def test_insufficient_data(self):
        single_rater = np.array([[1.0, 2.0, 3.0]])
        assert math.isnan(_krippendorff_alpha(single_rater))

        single_item = np.array([[1.0], [2.0]])
        assert math.isnan(_krippendorff_alpha(single_item))
|
||||
|
||||
|
||||
class TestReliabilityMetrics:
    def test_reliability_computed_after_scoring(self, completions):
        ensemble = _make_ensemble(
            "mean", [ConstantReward(value=1.0), ConstantReward(value=1.0)]
        )
        ensemble.compute(completions)
        metrics = ensemble.reliability_metrics()
        assert "alpha" in metrics
        # Identical raters -> perfect agreement.
        assert math.isclose(metrics["alpha"], 1.0, rel_tol=1e-9)

    def test_disagreement_tracked(self, completions):
        ensemble = _make_ensemble(
            "mean", [ConstantReward(value=0.0), ConstantReward(value=10.0)]
        )
        ensemble.compute(completions)
        disagreements = ensemble.last_disagreement_scores
        assert disagreements is not None
        assert len(disagreements) == len(completions)
        # Variance of [0.0, 10.0] is 25.0 for every item.
        for item_variance in disagreements:
            assert math.isclose(item_variance, 25.0, rel_tol=1e-9)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Registry integration
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestRegistryIntegration:
    def test_create_via_registry(self, test_registry):
        # EnsembleReward.__init__ resolves sub-rewards via the global registry,
        # so we must register our test fixtures there too.
        from atroposlib.envs.reward_fns.registry import registry as global_registry

        global_registry.register(name="test_constant")(ConstantReward)
        test_registry.register(name="ensemble")(EnsembleReward)

        config = {
            "type": "ensemble",
            "rewards": ["test_constant", "test_constant"],
            "strategy": "median",
        }
        ensemble = test_registry.create(config)

        assert isinstance(ensemble, EnsembleReward)
        assert ensemble.strategy == "median"
        assert len(ensemble.reward_functions) == 2
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Edge cases
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestEdgeCases:
    def test_empty_completions(self):
        ensemble = _make_ensemble("mean", [ConstantReward(value=1.0)])
        assert ensemble.compute([]) == []

    def test_invalid_strategy_raises(self):
        with pytest.raises(ValueError, match="Invalid strategy"):
            EnsembleReward(rewards=[], strategy="nonexistent")

    def test_name_format(self):
        ensemble = _make_ensemble(
            "median",
            [ConstantReward(value=1.0), LengthReward()],
        )
        ensemble_name = ensemble.name
        for fragment in ("ensemble_median", "constantreward", "lengthreward"):
            assert fragment in ensemble_name
|
||||
Loading…
Add table
Add a link
Reference in a new issue