diff --git a/reasoning_gym/coaching/__init__.py b/reasoning_gym/coaching/__init__.py
index 97e89844..2b4927a8 100644
--- a/reasoning_gym/coaching/__init__.py
+++ b/reasoning_gym/coaching/__init__.py
@@ -1,6 +1,8 @@
 from .attributes import AttributeDefinition, RangeAttributeDefinition, ScalarAttributeDefinition
 from .base_curriculum import BaseCurriculum
-from .coach import Coach, GroupedScores, ScoreBoard, ScoreStats
+from .curriculum_config import CurriculumAttributeConfig, CurriculumExperimentConfig
+from .experiment import CurriculumExperiment, Experiment
+from .score_board import GroupedScores, ScoreBoard, ScoreStats
 
 __all__ = [
     "AttributeType",
@@ -8,8 +10,11 @@ __all__ = [
     "ScalarAttributeDefinition",
     "RangeAttributeDefinition",
     "BaseCurriculum",
-    "Coach",
     "ScoreBoard",
     "GroupedScores",
     "ScoreStats",
+    "Experiment",
+    "CurriculumExperiment",
+    "CurriculumAttributeConfig",
+    "CurriculumExperimentConfig",
 ]
diff --git a/reasoning_gym/coaching/experiment.py b/reasoning_gym/coaching/experiment.py
index 7b15db8c..bb6b74d0 100644
--- a/reasoning_gym/coaching/experiment.py
+++ b/reasoning_gym/coaching/experiment.py
@@ -7,8 +7,8 @@ from reasoning_gym.coaching.base_curriculum import CurriculumContext
 from ..composite import CompositeConfig, CompositeDataset, DatasetSpec
 from ..factory import create_curriculum
 from ..version_manager import DatasetVersionManager
-from .coach import ScoreBoard
 from .curriculum_config import CurriculumExperimentConfig
+from .score_board import ScoreBoard
 
 
 class Experiment:
diff --git a/reasoning_gym/coaching/coach.py b/reasoning_gym/coaching/score_board.py
similarity index 75%
rename from reasoning_gym/coaching/coach.py
rename to reasoning_gym/coaching/score_board.py
index f1bf39b9..69413473 100644
--- a/reasoning_gym/coaching/coach.py
+++ b/reasoning_gym/coaching/score_board.py
@@ -191,63 +191,3 @@ class ScoreBoard:
         total_scores = sum(len(scores) for scores in result.values())
 
         return GroupedScores(scores=result, total_scores=total_scores)
-
-
-class Coach(ProceduralDataset):
-    """A dataset wrapper that tracks performance and adjusts difficulty
-
-    The Coach wraps a ProceduralDataset (typically a CompositeDataset) and:
-    1. Tracks scores and metadata in a ScoreBoard
-    2. Adjusts difficulty based on performance (to be implemented)
-    """
-
-    def __init__(self, dataset: ProceduralDataset, score_log: Optional[Union[str, Path]] = None):
-        """Initialize with inner dataset
-
-        Args:
-            dataset: The ProceduralDataset to wrap
-            score_log: Optional path to jsonl file for logging scores
-        """
-        super().__init__(config=dataset.config, seed=dataset.seed, size=dataset.size)
-        self.dataset = dataset
-        self.score_board = ScoreBoard()
-        self.score_log = Path(score_log) if score_log else None
-
-    def __getitem__(self, idx: int) -> dict:
-        """Forward item generation to inner dataset"""
-        return self.dataset[idx]
-
-    def score_answer(
-        self, answer: Optional[str], entry: dict[str, Any], conversation: Optional[list[dict]] = None
-    ) -> float:
-        """Score answer and track results
-
-        Args:
-            answer: The answer to score
-            entry: The task entry containing question/answer/metadata
-            conversation: Optional conversation history as list of message dicts
-
-        Returns:
-            float: Score between 0.0 and 1.0
-        """
-        # Get score from inner dataset
-        score = self.dataset.score_answer(answer, entry)
-
-        # Track score and metadata
-        self.score_board.add_score(score=score, metadata=entry["metadata"], conversation=conversation)
-
-        # Log score if logging is enabled
-        if self.score_log is not None:
-            log_entry = {"score": score, "answer": answer, "entry": entry, "conversation": conversation}
-            with self.score_log.open("a") as f:
-                json.dump(log_entry, f)
-                f.write("\n")
-
-        return score
-
-    def update_difficulty(self) -> None:
-        """Update difficulty based on recent performance
-
-        To be implemented in future versions.
-        """
-        pass  # Placeholder for future difficulty adjustment logic
diff --git a/tests/test_coaching.py b/tests/test_score_board.py
similarity index 55%
rename from tests/test_coaching.py
rename to tests/test_score_board.py
index 1741e87a..26fd474a 100644
--- a/tests/test_coaching.py
+++ b/tests/test_score_board.py
@@ -1,4 +1,3 @@
-import json
 import math
 from collections import OrderedDict
 
@@ -6,26 +5,40 @@ import pytest
 
 from reasoning_gym.arithmetic.chain_sum import ChainSumConfig, ChainSumDataset
 from reasoning_gym.arithmetic.leg_counting import LegCountingConfig
-from reasoning_gym.coaching import Coach, GroupedScores
+from reasoning_gym.coaching import (
+    CurriculumAttributeConfig,
+    CurriculumExperiment,
+    CurriculumExperimentConfig,
+    GroupedScores,
+)
+from reasoning_gym.coaching.base_curriculum import DefaultCurriculumContext, RangeAttributeMode
 from reasoning_gym.composite import CompositeConfig, CompositeDataset, DatasetSpec
 
 
-def test_coach_with_chain_sum():
-    # Create a small ChainSum dataset
-    config = ChainSumConfig(min_terms=2, max_terms=3, min_digits=1, max_digits=2, size=10, seed=42)
-    dataset = ChainSumDataset(config)
-    coach = Coach(dataset)
+def test_score_aggregation():
+    config = CurriculumExperimentConfig(
+        curricula={"leg_counting": CurriculumAttributeConfig(attribute_levels={"num_animals": 2}, weight=1.0)}
+    )
+
+    # Create experiment
+    experiment = CurriculumExperiment(
+        name="test_experiment",
+        config=config,
+        context=DefaultCurriculumContext(mode=RangeAttributeMode.INCLUSIVE),
+        size=10,
+        seed=42,
+    )
 
     # Simulate an agent working on tasks
     for i in range(5):
-        item = coach[i]
+        item = experiment.get_dataset_entry(i)
 
         # Simulate some correct and incorrect answers
         if i % 2 == 0:
             # Correct answer
-            score = coach.score_answer(
+            score = experiment.score_answer_with_id(
                 answer=item["answer"],
-                entry=item,
+                entry_id=item["metadata"]["entry_id"],
                 conversation=[
{"role": "user", "content": item["question"]}, {"role": "assistant", "content": item["answer"]}, @@ -34,18 +47,18 @@ def test_coach_with_chain_sum(): assert score == 1.0 else: # Incorrect answer (None) - score = coach.score_answer( + score = experiment.score_answer_with_id( answer=None, - entry=item, + entry_id=item["metadata"]["entry_id"], conversation=[ {"role": "user", "content": item["question"]}, - {"role": "assistant", "content": "I don't know"}, + {"role": "assistant", "content": item["answer"]}, ], ) assert score == 0.0 # Test score aggregation - aggregated = coach.score_board.aggregate() + aggregated = experiment.score_board.aggregate() # Verify we have scores grouped by difficulty parameters assert len(aggregated.scores) > 0 @@ -56,18 +69,18 @@ def test_coach_with_chain_sum(): # Each inner tuple should be (param_name, value) or (param_name, (min_value, max_value)) for param in key: assert isinstance(param, tuple) - assert param[0] in ("source", "idx", "num_terms", "num_digits") + assert param[0] in ("source", "idx", "num_animals", "num_instances") # Test aggregation with last_n - last_3 = coach.score_board.aggregate(last_n=3) + last_3 = experiment.score_board.aggregate(last_n=3) assert len(last_3.scores) > 0 # Verify total scores count assert last_3.total_scores == 3 # Verify conversation tracking - assert len(coach.score_board.conversations) == 5 - for conv in coach.score_board.conversations: + assert len(experiment.score_board.conversations) == 5 + for conv in experiment.score_board.conversations: assert len(conv) == 2 # user question and assistant response assert conv[0]["role"] == "user" assert conv[1]["role"] == "assistant" @@ -93,43 +106,38 @@ def test_coach_with_chain_sum(): assert stats_tuple[0] == 0 # count should be 0 for empty list assert all(math.isnan(v) for v in stats_tuple[1:]) # stats should be NaN - print(aggregated) - print(stats) - # Test clear functionality - coach.score_board.clear() - assert len(coach.score_board.scores) == 0 - assert len(coach.score_board.metadata) == 0 - assert len(coach.score_board.conversations) == 0 - assert len(coach.score_board.aggregate().scores) == 0 + experiment.score_board.clear() + assert len(experiment.score_board.scores) == 0 + assert len(experiment.score_board.metadata) == 0 + assert len(experiment.score_board.conversations) == 0 + assert len(experiment.score_board.aggregate().scores) == 0 -def test_coach_with_composite(): +def test_experiment_with_composite(): # Create configs for both datasets - chain_sum_config = ChainSumConfig(min_terms=2, max_terms=3, min_digits=1, max_digits=2, size=10) - leg_counting_config = LegCountingConfig(min_animals=2, max_animals=3, size=10) - - # Create composite config - composite_config = CompositeConfig( - size=20, - seed=42, - datasets=[ - DatasetSpec(name="chain_sum", weight=1.0, config=chain_sum_config.__dict__), - DatasetSpec(name="leg_counting", weight=1.0, config=leg_counting_config.__dict__), - ], + config = CurriculumExperimentConfig( + curricula={ + "chain_sum": CurriculumAttributeConfig(attribute_levels={"num_terms": 2}, weight=1.0), + "leg_counting": CurriculumAttributeConfig(attribute_levels={"num_animals": 2}, weight=1.0), + } + ) + # Create experiment + experiment = CurriculumExperiment( + name="test_experiment", + config=config, + context=DefaultCurriculumContext(mode=RangeAttributeMode.INCLUSIVE), + size=10, + seed=42, ) - - # Create composite dataset and coach - dataset = CompositeDataset(composite_config) - coach = Coach(dataset) # Score some answers for i in range(5): - 
item = coach[i] + item = experiment.get_dataset_entry(i) # Correct answers for even indices - score = coach.score_answer( + score = experiment.score_answer_with_id( answer=item["answer"] if i % 2 == 0 else None, - entry=item, + entry_id=item["metadata"]["entry_id"], conversation=[ {"role": "user", "content": item["question"]}, {"role": "assistant", "content": item["answer"] if i % 2 == 0 else "I don't know"}, @@ -138,7 +146,7 @@ def test_coach_with_composite(): assert score in (0.0, 1.0) # Test aggregation - aggregated = coach.score_board.aggregate() + aggregated = experiment.score_board.aggregate() assert len(aggregated.scores) > 0 # Verify source dataset info is first in keys @@ -154,24 +162,6 @@ def test_coach_with_composite(): assert isinstance(values[0], int) assert all(isinstance(v, float) for v in values[1:]) - print("\nComposite Dataset Stats:") - print(stats) - - # Test config update - coach.dataset.update_dataset_config("chain_sum", {"min_terms": 4, "max_terms": 5}) - - # Verify the config was updated - chain_sum_dataset = coach.dataset.datasets["chain_sum"] - assert chain_sum_dataset.config.min_terms == 4 - assert chain_sum_dataset.config.max_terms == 5 - - # Score some more items to verify new config is in effect - for i in range(3): - item = coach[i + 5] # Use different indices - if "chain_sum" in item["metadata"]["source_dataset"]: - metadata = item["metadata"] - assert metadata["num_terms"] >= 4 - def test_grouped_scores_str(): # Test raw scores string representation @@ -198,41 +188,3 @@ def test_grouped_scores_str(): # Test empty scores empty = GroupedScores(scores=OrderedDict(), total_scores=0) assert str(empty) == "No scores recorded" - - -def test_coach_score_logging(tmp_path): - # Create a log file in the temporary directory - log_file = tmp_path / "scores.jsonl" - - # Create dataset and coach with logging - config = ChainSumConfig(min_terms=2, max_terms=3, min_digits=1, max_digits=2, size=10, seed=42) - dataset = ChainSumDataset(config) - coach = Coach(dataset, score_log=log_file) - - # Score a few answers - for i in range(3): - item = coach[i] - coach.score_answer( - answer=item["answer"] if i % 2 == 0 else None, - entry=item, - conversation=[ - {"role": "user", "content": item["question"]}, - {"role": "assistant", "content": item["answer"] if i % 2 == 0 else "I don't know"}, - ], - ) - - # Verify log file contents - assert log_file.exists() - - # Read and parse log entries - log_entries = [json.loads(line) for line in log_file.open()] - assert len(log_entries) == 3 - - # Verify log entry structure - for i, entry in enumerate(log_entries): - assert "score" in entry - assert "entry" in entry - assert "metadata" in entry["entry"] - assert "conversation" in entry - assert entry["score"] == (1.0 if i % 2 == 0 else 0.0) - assert len(entry["conversation"]) == 2
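
---

Note for reviewers: Coach is removed without a drop-in replacement; callers migrate to
CurriculumExperiment plus the ScoreBoard that now lives in score_board.py. Below is a
minimal migration sketch using only the API exercised by the tests in this patch
(get_dataset_entry, score_answer_with_id, score_board.aggregate); the experiment name
and loop bounds are illustrative, and the sketch is not part of the patch itself. Also
note that Coach's optional jsonl score_log has no equivalent after this change.

    from reasoning_gym.coaching import (
        CurriculumAttributeConfig,
        CurriculumExperiment,
        CurriculumExperimentConfig,
    )
    from reasoning_gym.coaching.base_curriculum import DefaultCurriculumContext, RangeAttributeMode

    # Configure one curriculum per dataset, mirroring test_score_aggregation above.
    config = CurriculumExperimentConfig(
        curricula={"leg_counting": CurriculumAttributeConfig(attribute_levels={"num_animals": 2}, weight=1.0)}
    )
    experiment = CurriculumExperiment(
        name="demo",  # illustrative name
        config=config,
        context=DefaultCurriculumContext(mode=RangeAttributeMode.INCLUSIVE),
        size=10,
        seed=42,
    )

    for i in range(5):
        item = experiment.get_dataset_entry(i)
        # Old: coach.score_answer(answer, entry=item, ...).
        # New: scoring is keyed by the entry id stored in the item's metadata.
        score = experiment.score_answer_with_id(
            answer=item["answer"],
            entry_id=item["metadata"]["entry_id"],
            conversation=[
                {"role": "user", "content": item["question"]},
                {"role": "assistant", "content": item["answer"]},
            ],
        )

    # Scores remain grouped by difficulty parameters, as before.
    print(experiment.score_board.aggregate(last_n=3))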