reasoning-gym/tests/test_coaching.py

239 lines
8.4 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import json
import math
from collections import OrderedDict
import pytest
from reasoning_gym.arithmetic.chain_sum import ChainSumConfig, ChainSumDataset
from reasoning_gym.arithmetic.leg_counting import LegCountingConfig
from reasoning_gym.coaching import Coach, GroupedScores
from reasoning_gym.composite import CompositeConfig, CompositeDataset, DatasetSpec
def test_coach_with_chain_sum():
    """Exercise Coach scoring, aggregation, stats and clearing on a ChainSum dataset.

    Five items are scored (reference answer on even indices, ``None`` on odd
    ones); the test then checks the grouped score keys, ``last_n`` aggregation,
    conversation tracking, the statistics tuples, empty-group handling and
    ``clear()``.
    """
    # Small seeded ChainSum dataset wrapped in a Coach
    coach = Coach(
        ChainSumDataset(
            ChainSumConfig(min_terms=2, max_terms=3, min_digits=1, max_digits=2, size=10, seed=42)
        )
    )

    # Simulate an agent: even indices answer correctly, odd indices give no answer
    for idx in range(5):
        entry = coach[idx]
        is_correct = idx % 2 == 0
        submitted = entry["answer"] if is_correct else None
        reply = entry["answer"] if is_correct else "I don't know"
        result = coach.score_answer(
            answer=submitted,
            entry=entry,
            conversation=[
                {"role": "user", "content": entry["question"]},
                {"role": "assistant", "content": reply},
            ],
        )
        assert result == (1.0 if is_correct else 0.0)

    board = coach.score_board

    # Aggregated scores must be grouped under tuple-of-(name, value) keys
    aggregated = board.aggregate()
    assert len(aggregated.scores) > 0
    allowed_params = ("num_terms", "num_digits")
    for group_key in aggregated.scores:
        assert isinstance(group_key, tuple)
        for pair in group_key:
            assert isinstance(pair, tuple)
            assert pair[0] in allowed_params
            assert isinstance(pair[1], int)

    # Restricting to the last three scores must report exactly three
    recent = board.aggregate(last_n=3)
    assert len(recent.scores) > 0
    assert recent.total_scores == 3

    # One two-message conversation (user, assistant) per scored item
    assert len(board.conversations) == 5
    for exchange in board.conversations:
        assert len(exchange) == 2  # user question and assistant response
        user_msg, assistant_msg = exchange[0], exchange[1]
        assert user_msg["role"] == "user"
        assert assistant_msg["role"] == "assistant"

    # Each stats entry is (count, mean, std, min, max)
    summary = aggregated.stats()
    for row in summary.scores.values():
        assert isinstance(row, tuple)
        assert len(row) == 5
        assert isinstance(row[0], int)  # count should be int
        for stat in row[1:]:
            assert isinstance(stat, float)  # stats should be floats

    # Stats of a completely empty GroupedScores stay empty
    nothing = GroupedScores(scores=OrderedDict(), total_scores=0)
    assert len(nothing.stats().scores) == 0

    # With ignore_empty=False an empty group is kept: count 0, remaining stats NaN
    lone_empty_group = OrderedDict({(("test", 1),): []})
    kept = GroupedScores(scores=lone_empty_group, total_scores=0).stats(ignore_empty=False)
    assert len(kept.scores) == 1
    first_row = next(iter(kept.scores.values()))
    assert first_row[0] == 0
    for stat in first_row[1:]:
        assert math.isnan(stat)

    print(aggregated)
    print(summary)

    # Clearing the score board must drop scores, metadata and conversations
    board.clear()
    for store in (board.scores, board.metadata, board.conversations):
        assert len(store) == 0
    assert len(board.aggregate().scores) == 0
def test_coach_with_composite():
    """Exercise Coach on a CompositeDataset mixing chain_sum and leg_counting.

    Checks that aggregated keys start with ("source", ...) and ("idx", ...),
    that stats rows have the expected shape, and that updating the chain_sum
    sub-dataset config is reflected in the dataset and in later items.
    """
    # Per-dataset configs, passed to the composite as plain dicts
    sub_configs = {
        "chain_sum": ChainSumConfig(min_terms=2, max_terms=3, min_digits=1, max_digits=2, size=10),
        "leg_counting": LegCountingConfig(min_animals=2, max_animals=3, size=10),
    }
    specs = [
        DatasetSpec(name="chain_sum", weight=1.0, config=sub_configs["chain_sum"].__dict__),
        DatasetSpec(name="leg_counting", weight=1.0, config=sub_configs["leg_counting"].__dict__),
    ]
    coach = Coach(CompositeDataset(CompositeConfig(size=20, seed=42, datasets=specs)))

    # Score five items: reference answer on even indices, no answer on odd ones
    for idx in range(5):
        entry = coach[idx]
        if idx % 2 == 0:
            submitted = entry["answer"]
            reply = entry["answer"]
        else:
            submitted = None
            reply = "I don't know"
        result = coach.score_answer(
            answer=submitted,
            entry=entry,
            conversation=[
                {"role": "user", "content": entry["question"]},
                {"role": "assistant", "content": reply},
            ],
        )
        assert result in (0.0, 1.0)

    # Aggregation keys must lead with the source dataset and its index
    aggregated = coach.score_board.aggregate()
    assert len(aggregated.scores) > 0
    for group_key in aggregated.scores:
        assert group_key[0][0] == "source"  # First tuple should be ("source", dataset_name)
        assert group_key[1][0] == "idx"  # Second tuple should be ("idx", index)

    # Each stats entry is (count, mean, std, min, max)
    summary = aggregated.stats()
    for row in summary.scores.values():
        assert isinstance(row, tuple)
        assert len(row) == 5
        assert isinstance(row[0], int)
        for stat in row[1:]:
            assert isinstance(stat, float)

    print("\nComposite Dataset Stats:")
    print(summary)

    # Update the chain_sum sub-config and confirm the dataset picked it up
    coach.dataset.update_dataset_config("chain_sum", {"min_terms": 4, "max_terms": 5})
    updated_cfg = coach.dataset.datasets["chain_sum"].config
    assert updated_cfg.min_terms == 4
    assert updated_cfg.max_terms == 5

    # Items at fresh indices that come from chain_sum must reflect the new minimum
    for idx in range(5, 8):
        meta = coach[idx]["metadata"]
        if "chain_sum" not in meta["source_dataset"]:
            continue
        assert meta["difficulty"]["num_terms"] >= 4
def test_grouped_scores_str():
    """Check the string reports of GroupedScores for raw scores, stats and empty input."""
    # Two difficulty groups holding five raw scores in total
    grouped = GroupedScores(
        scores=OrderedDict(
            [
                ((("num_terms", 2), ("num_digits", 1)), [1.0, 0.0, 1.0]),
                ((("num_terms", 3), ("num_digits", 2)), [0.5, 0.5]),
            ]
        ),
        total_scores=5,
    )

    # Raw report: total, per-group headers with counts, and formatted values
    raw_report = str(grouped)
    expected_raw = (
        "Total scores: 5",
        "(num_terms=2, num_digits=1): n=3",
        "(num_terms=3, num_digits=2): n=2",
        "Values: 1.00, 0.00, 1.00",
        "Values: 0.50, 0.50",
    )
    for fragment in expected_raw:
        assert fragment in raw_report

    # Stats report: must mention mean, std, min and max labels
    stats_report = str(grouped.stats())
    for label in ("μ=", "σ=", "min=", "max="):
        assert label in stats_report

    # An empty GroupedScores renders a fixed placeholder message
    assert str(GroupedScores(scores=OrderedDict(), total_scores=0)) == "No scores recorded"
def test_coach_score_logging(tmp_path):
    """Verify that a Coach created with ``score_log`` writes one JSON line per scored answer.

    Args:
        tmp_path: pytest-provided temporary directory (a ``pathlib.Path``) in
            which the ``scores.jsonl`` log file is created.
    """
    # Create a log file in the temporary directory
    log_file = tmp_path / "scores.jsonl"

    # Create dataset and coach with logging
    config = ChainSumConfig(min_terms=2, max_terms=3, min_digits=1, max_digits=2, size=10, seed=42)
    dataset = ChainSumDataset(config)
    coach = Coach(dataset, score_log=log_file)

    # Score a few answers: reference answer on even indices, no answer on odd ones
    for i in range(3):
        item = coach[i]
        coach.score_answer(
            answer=item["answer"] if i % 2 == 0 else None,
            entry=item,
            conversation=[
                {"role": "user", "content": item["question"]},
                {"role": "assistant", "content": item["answer"] if i % 2 == 0 else "I don't know"},
            ],
        )

    # Verify log file contents
    assert log_file.exists()

    # Read and parse log entries; the context manager closes the handle
    # deterministically instead of leaking it (avoids ResourceWarning).
    with log_file.open() as handle:
        log_entries = [json.loads(line) for line in handle]
    assert len(log_entries) == 3

    # Verify log entry structure
    for i, entry in enumerate(log_entries):
        assert "score" in entry
        assert "entry" in entry
        assert "metadata" in entry["entry"]
        assert "conversation" in entry
        assert entry["score"] == (1.0 if i % 2 == 0 else 0.0)
        assert len(entry["conversation"]) == 2