mirror of
https://github.com/open-thought/reasoning-gym.git
synced 2026-04-19 12:58:07 +00:00
239 lines
8.4 KiB
Python
import json
|
||
import math
|
||
from collections import OrderedDict
|
||
|
||
import pytest
|
||
|
||
from reasoning_gym.arithmetic.chain_sum import ChainSumConfig, ChainSumDataset
|
||
from reasoning_gym.arithmetic.leg_counting import LegCountingConfig
|
||
from reasoning_gym.coaching import Coach, GroupedScores
|
||
from reasoning_gym.composite import CompositeConfig, CompositeDataset, DatasetSpec
|
||
|
||
|
||
def test_coach_with_chain_sum():
    """Exercise Coach scoring, aggregation, stats and clearing on a ChainSum dataset."""
    # A tiny, reproducible ChainSum dataset wrapped by a Coach.
    config = ChainSumConfig(min_terms=2, max_terms=3, min_digits=1, max_digits=2, size=10, seed=42)
    dataset = ChainSumDataset(config)
    coach = Coach(dataset)

    # Simulate an agent: even indices answer correctly, odd indices fail.
    for idx in range(5):
        task = coach[idx]
        is_correct = idx % 2 == 0
        submitted = task["answer"] if is_correct else None
        reply = task["answer"] if is_correct else "I don't know"
        result = coach.score_answer(
            answer=submitted,
            entry=task,
            conversation=[
                {"role": "user", "content": task["question"]},
                {"role": "assistant", "content": reply},
            ],
        )
        assert result == (1.0 if is_correct else 0.0)

    # Aggregate all recorded scores, grouped by difficulty parameters.
    aggregated = coach.score_board.aggregate()
    assert len(aggregated.scores) > 0

    # Each group key is a tuple of (param_name, value) pairs.
    for group_key in aggregated.scores:
        assert isinstance(group_key, tuple)
        for param in group_key:
            assert isinstance(param, tuple)
            assert param[0] in ("num_terms", "num_digits")
            assert isinstance(param[1], int)

    # Aggregation can be limited to the most recent entries.
    last_3 = coach.score_board.aggregate(last_n=3)
    assert len(last_3.scores) > 0
    assert last_3.total_scores == 3

    # Every scored item stored its two-message conversation.
    assert len(coach.score_board.conversations) == 5
    for conv in coach.score_board.conversations:
        assert len(conv) == 2  # user question and assistant response
        assert conv[0]["role"] == "user"
        assert conv[1]["role"] == "assistant"

    # Per-group statistics come back as (count, mean, std, min, max).
    stats = aggregated.stats()
    for values in stats.scores.values():
        assert isinstance(values, tuple)
        assert len(values) == 5  # (count, mean, std, min, max)
        assert isinstance(values[0], int)  # count should be int
        assert all(isinstance(v, float) for v in values[1:])  # stats should be floats

    # No groups at all -> empty stats.
    empty_stats = GroupedScores(scores=OrderedDict(), total_scores=0).stats()
    assert len(empty_stats.scores) == 0

    # An empty group is kept (with NaN stats) when ignore_empty=False.
    empty_group = OrderedDict({(("test", 1),): []})
    non_ignoring_stats = GroupedScores(scores=empty_group, total_scores=0).stats(ignore_empty=False)
    assert len(non_ignoring_stats.scores) == 1
    stats_tuple = next(iter(non_ignoring_stats.scores.values()))
    assert stats_tuple[0] == 0  # count should be 0 for empty list
    assert all(math.isnan(v) for v in stats_tuple[1:])  # stats should be NaN

    print(aggregated)
    print(stats)

    # Clearing wipes scores, metadata, conversations and aggregation output.
    coach.score_board.clear()
    assert len(coach.score_board.scores) == 0
    assert len(coach.score_board.metadata) == 0
    assert len(coach.score_board.conversations) == 0
    assert len(coach.score_board.aggregate().scores) == 0
|
||
|
||
|
||
def test_coach_with_composite():
    """Verify Coach scoring, aggregation and config updates over a CompositeDataset."""
    # Configurations for the two sub-datasets.
    chain_sum_config = ChainSumConfig(min_terms=2, max_terms=3, min_digits=1, max_digits=2, size=10)
    leg_counting_config = LegCountingConfig(min_animals=2, max_animals=3, size=10)

    # Composite dataset mixing both sub-datasets with equal weight.
    composite_config = CompositeConfig(
        size=20,
        seed=42,
        datasets=[
            DatasetSpec(name="chain_sum", weight=1.0, config=chain_sum_config.__dict__),
            DatasetSpec(name="leg_counting", weight=1.0, config=leg_counting_config.__dict__),
        ],
    )

    dataset = CompositeDataset(composite_config)
    coach = Coach(dataset)

    # Alternate correct answers (even indices) with failures (odd indices).
    for idx in range(5):
        task = coach[idx]
        is_correct = idx % 2 == 0
        result = coach.score_answer(
            answer=task["answer"] if is_correct else None,
            entry=task,
            conversation=[
                {"role": "user", "content": task["question"]},
                {"role": "assistant", "content": task["answer"] if is_correct else "I don't know"},
            ],
        )
        assert result in (0.0, 1.0)

    aggregated = coach.score_board.aggregate()
    assert len(aggregated.scores) > 0

    # Composite keys lead with the originating dataset and its index.
    for group_key in aggregated.scores:
        assert group_key[0][0] == "source"  # First tuple should be ("source", dataset_name)
        assert group_key[1][0] == "idx"  # Second tuple should be ("idx", index)

    # Stats follow the (count, mean, std, min, max) shape.
    stats = aggregated.stats()
    for values in stats.scores.values():
        assert isinstance(values, tuple)
        assert len(values) == 5  # (count, mean, std, min, max)
        assert isinstance(values[0], int)
        assert all(isinstance(v, float) for v in values[1:])

    print("\nComposite Dataset Stats:")
    print(stats)

    # Update one sub-dataset's config in place ...
    coach.dataset.update_dataset_config("chain_sum", {"min_terms": 4, "max_terms": 5})

    # ... and confirm the stored config reflects the new bounds.
    chain_sum_dataset = coach.dataset.datasets["chain_sum"]
    assert chain_sum_dataset.config.min_terms == 4
    assert chain_sum_dataset.config.max_terms == 5

    # Freshly generated chain_sum items must honour the raised term minimum.
    for offset in range(3):
        task = coach[offset + 5]  # use different indices than before
        if "chain_sum" in task["metadata"]["source_dataset"]:
            assert task["metadata"]["difficulty"]["num_terms"] >= 4
|
||
|
||
|
||
def test_grouped_scores_str():
    """Check the human-readable rendering of GroupedScores and of its stats."""
    # Two difficulty groups with raw score lists.
    raw = OrderedDict()
    raw[(("num_terms", 2), ("num_digits", 1))] = [1.0, 0.0, 1.0]
    raw[(("num_terms", 3), ("num_digits", 2))] = [0.5, 0.5]
    grouped = GroupedScores(scores=raw, total_scores=5)

    # Raw rendering shows totals, group headers and formatted values.
    rendered = str(grouped)
    assert "Total scores: 5" in rendered
    assert "(num_terms=2, num_digits=1): n=3" in rendered
    assert "(num_terms=3, num_digits=2): n=2" in rendered
    assert "Values: 1.00, 0.00, 1.00" in rendered
    assert "Values: 0.50, 0.50" in rendered

    # Stats rendering uses the Greek-letter summary fields.
    stats_report = str(grouped.stats())
    for marker in ("μ=", "σ=", "min=", "max="):
        assert marker in stats_report

    # No recorded scores -> fixed placeholder text.
    empty = GroupedScores(scores=OrderedDict(), total_scores=0)
    assert str(empty) == "No scores recorded"
|
||
|
||
|
||
def test_coach_score_logging(tmp_path):
    """Coach with score_log writes one JSONL record per scored answer.

    Each record must carry the score, the original entry (with its
    metadata) and the full two-message conversation.
    """
    # Log file inside pytest's per-test temporary directory.
    log_file = tmp_path / "scores.jsonl"

    # Dataset and coach configured to append scores to the log.
    config = ChainSumConfig(min_terms=2, max_terms=3, min_digits=1, max_digits=2, size=10, seed=42)
    dataset = ChainSumDataset(config)
    coach = Coach(dataset, score_log=log_file)

    # Score three answers: correct on even indices, "I don't know" on odd ones.
    for i in range(3):
        item = coach[i]
        coach.score_answer(
            answer=item["answer"] if i % 2 == 0 else None,
            entry=item,
            conversation=[
                {"role": "user", "content": item["question"]},
                {"role": "assistant", "content": item["answer"] if i % 2 == 0 else "I don't know"},
            ],
        )

    # The log must exist and contain one JSON line per scored answer.
    assert log_file.exists()

    # Fix: the original iterated log_file.open() inside a comprehension and
    # never closed the handle (ResourceWarning under pytest); read_text()
    # reads and closes in one call. Skip any trailing blank line.
    log_entries = [json.loads(line) for line in log_file.read_text().splitlines() if line]
    assert len(log_entries) == 3

    # Each record carries score, entry (with metadata) and conversation.
    for i, entry in enumerate(log_entries):
        assert "score" in entry
        assert "entry" in entry
        assert "metadata" in entry["entry"]
        assert "conversation" in entry
        assert entry["score"] == (1.0 if i % 2 == 0 else 0.0)
        assert len(entry["conversation"]) == 2