diff --git a/reasoning_gym/algebra/simple_equations.py b/reasoning_gym/algebra/simple_equations.py index 5a85fcb5..c8de94b8 100644 --- a/reasoning_gym/algebra/simple_equations.py +++ b/reasoning_gym/algebra/simple_equations.py @@ -3,8 +3,7 @@ import string from dataclasses import dataclass from typing import Optional, Tuple -import sympy -from sympy import Eq, Symbol, solve +from sympy import Symbol from ..factory import ProceduralDataset, register_dataset diff --git a/reasoning_gym/algorithmic/caesar_cipher.py b/reasoning_gym/algorithmic/caesar_cipher.py index 01b8f4ed..2afc6d1f 100644 --- a/reasoning_gym/algorithmic/caesar_cipher.py +++ b/reasoning_gym/algorithmic/caesar_cipher.py @@ -1,10 +1,8 @@ """Caesar cipher task generator""" -import re from dataclasses import dataclass from random import Random -from string import ascii_uppercase -from typing import List, Optional +from typing import Optional from reasoning_gym.data import read_data_file diff --git a/reasoning_gym/algorithmic/letter_counting.py b/reasoning_gym/algorithmic/letter_counting.py index 8f2590dd..2ed65737 100644 --- a/reasoning_gym/algorithmic/letter_counting.py +++ b/reasoning_gym/algorithmic/letter_counting.py @@ -3,7 +3,7 @@ import re from dataclasses import dataclass from random import Random -from typing import List, Optional +from typing import Optional from reasoning_gym.data import read_data_file diff --git a/reasoning_gym/algorithmic/sentence_reordering.py b/reasoning_gym/algorithmic/sentence_reordering.py index bc509b8c..acb7cd23 100644 --- a/reasoning_gym/algorithmic/sentence_reordering.py +++ b/reasoning_gym/algorithmic/sentence_reordering.py @@ -3,7 +3,7 @@ import re from dataclasses import dataclass from random import Random -from typing import List, Optional +from typing import Optional from ..data import read_data_file from ..factory import ProceduralDataset, register_dataset diff --git a/reasoning_gym/algorithmic/word_sequence_reversal.py b/reasoning_gym/algorithmic/word_sequence_reversal.py index 5ad94ea7..c678e173 100644 --- a/reasoning_gym/algorithmic/word_sequence_reversal.py +++ b/reasoning_gym/algorithmic/word_sequence_reversal.py @@ -3,7 +3,7 @@ import re from dataclasses import dataclass from random import Random -from typing import List, Optional +from typing import Optional from ..data import read_data_file from ..factory import ProceduralDataset, register_dataset diff --git a/reasoning_gym/arithmetic/__init__.py b/reasoning_gym/arithmetic/__init__.py index bb56836d..f5e3eb1f 100644 --- a/reasoning_gym/arithmetic/__init__.py +++ b/reasoning_gym/arithmetic/__init__.py @@ -1,10 +1,5 @@ """ Arithmetic tasks for training reasoning capabilities: -- Basic arithmetic -- Chain sums -- Word problems -- Leg counting -- Time intervals """ from .basic_arithmetic import BasicArithmeticDataset, BasicArithmeticDatasetConfig @@ -21,13 +16,10 @@ from .time_intervals import TimeIntervalsConfig, TimeIntervalsDataset __all__ = [ "BasicArithmeticDataset", "BasicArithmeticDatasetConfig", - "basic_arithmetic_dataset", "ChainSum", "ChainSumConfig", "CalendarArithmeticConfig", "CalendarArithmeticDataset", - "Weekday", - "CalendarTask", "FractionSimplificationConfig", "FractionSimplificationDataset", "GCDConfig", diff --git a/reasoning_gym/arithmetic/calendar_arithmetic.py b/reasoning_gym/arithmetic/calendar_arithmetic.py index 78c42df8..bf12211c 100644 --- a/reasoning_gym/arithmetic/calendar_arithmetic.py +++ b/reasoning_gym/arithmetic/calendar_arithmetic.py @@ -3,7 +3,7 @@ import math import random from dataclasses import dataclass from datetime import date, timedelta -from enum import Enum, auto +from enum import Enum, StrEnum, auto from typing import Any, Dict, List, Optional, Tuple from ..factory import ProceduralDataset, register_dataset @@ -38,7 +38,7 @@ class Weekday(Enum): return self.name.capitalize() -class CalendarTask(Enum): +class CalendarTask(StrEnum): WEEKDAY_OFFSET = "weekday_offset" WEEKDAY_OF_DATE = "weekday_of_date" WEEKDAY_OF_DATE_FROM_FIRST_DATE = "weekday_of_date_from_first_day" diff --git a/reasoning_gym/arithmetic/chain_sum.py b/reasoning_gym/arithmetic/chain_sum.py index 24afda1a..30dcb0c4 100644 --- a/reasoning_gym/arithmetic/chain_sum.py +++ b/reasoning_gym/arithmetic/chain_sum.py @@ -65,8 +65,10 @@ class ChainSum(ProceduralDataset): "question": f"{expression} =", "answer": str(result), "metadata": { - "num_terms": num_terms, - "num_digits": num_digits, + "difficulty": { + "num_terms": num_terms, + "num_digits": num_digits, + }, "expression": expression, }, } diff --git a/reasoning_gym/arithmetic/leg_counting.py b/reasoning_gym/arithmetic/leg_counting.py index de950631..58b62b1a 100644 --- a/reasoning_gym/arithmetic/leg_counting.py +++ b/reasoning_gym/arithmetic/leg_counting.py @@ -111,7 +111,13 @@ class LegCountingDataset(ProceduralDataset): return { "question": question, "answer": str(total_legs), - "metadata": {"animals": animals, "total_legs": total_legs}, + "metadata": { + "difficulty": { + "num_animals": len(animals), + }, + "animals": animals, + "total_legs": total_legs, + }, } diff --git a/reasoning_gym/arithmetic/prime_factorization.py b/reasoning_gym/arithmetic/prime_factorization.py index c51f90ee..d8c9d2af 100644 --- a/reasoning_gym/arithmetic/prime_factorization.py +++ b/reasoning_gym/arithmetic/prime_factorization.py @@ -2,7 +2,7 @@ from dataclasses import dataclass from random import Random -from typing import List, Optional, Tuple +from typing import List, Optional from ..factory import ProceduralDataset, register_dataset diff --git a/reasoning_gym/coaching.py b/reasoning_gym/coaching.py new file mode 100644 index 00000000..ad14077a --- /dev/null +++ b/reasoning_gym/coaching.py @@ -0,0 +1,257 @@ +"""Coaching module for difficulty adjustment and score tracking""" + +import json +import math +from collections import OrderedDict +from dataclasses import dataclass, field +from pathlib import Path +from statistics import mean, stdev +from typing import Any, Dict, List, Optional, Tuple, Union + +from .dataset import ProceduralDataset + + +@dataclass +class ScoreStats: + """Container for score statistics with mean, std, min, max""" + + scores: OrderedDict[Tuple[Tuple[str, Any], ...], Tuple[int, float, float, float, float]] + + def __str__(self) -> str: + """Create a formatted report of the statistics + + Returns: + Multi-line string with statistics for each group + """ + if not self.scores: + return "No scores recorded" + + lines = [] + + for key, values in self.scores.items(): + params = ", ".join(f"{k}={v}" for k, v in key) + lines.append( + f"({params}): n={values[0]}, μ={values[1]:.3f}, σ={values[2]:.3f}, min={values[3]:.3f}, max={values[4]:.3f}" + ) + + return "\n".join(lines) + + +@dataclass +class GroupedScores: + """Container for grouped scores with total count""" + + scores: OrderedDict[Tuple[Tuple[str, Any], ...], List[float]] + total_scores: int + + def __str__(self) -> str: + """Create a formatted report of the scores + + Returns: + Multi-line string with score information for each difficulty group + """ + if not self.scores: + return "No scores recorded" + + lines = [] + lines.append(f"Total scores: {self.total_scores}") + lines.append("") + + for key, values in self.scores.items(): + # Format the parameter combinations + params = ", ".join(f"{k}={v}" for k, v in key) + stats = ( + len(values), + mean(values) if values else 0.0, + stdev(values) if len(values) > 1 else 0.0, + min(values) if values else 0.0, + max(values) if values else 0.0, + ) + lines.append( + f"({params}): n={stats[0]}, μ={stats[1]:.3f}, σ={stats[2]:.3f}, min={stats[3]:.3f}, max={stats[4]:.3f}" + ) + # Format score list, showing only last 100 if more + score_strs = [f"{x:.2f}" for x in values[-100:]] + if len(values) > 100: + score_strs.insert(0, "..") + lines.append(f" Values: {', '.join(score_strs)}") + + return "\n".join(lines) + + def stats(self, ignore_empty: bool = True) -> ScoreStats: + """Calculate statistics for each group of scores + + Args: + ignore_empty: If True, skip empty score lists + If False, use NaN values for empty lists + + Returns: + ScoreStats object containing statistics for each group + """ + result = OrderedDict() + + for key, values in self.scores.items(): + if not values and ignore_empty: + continue + + if not values: + # Empty list and not ignoring - use NaN + result[key] = (0, math.nan, math.nan, math.nan, math.nan) + else: + # Calculate stats as tuple + result[key] = ( + len(values), + mean(values), + stdev(values) if len(values) > 1 else 0.0, + min(values), + max(values), + ) + + return ScoreStats(scores=result) + + +@dataclass +class ScoreBoard: + """Tracks scores and metadata for coaching sessions""" + + scores: List[float] = field(default_factory=list) + metadata: List[Dict[str, Any]] = field(default_factory=list) + conversations: List[Optional[List[Dict]]] = field(default_factory=list) + + def add_score(self, score: float, metadata: Dict[str, Any], conversation: Optional[List[Dict]] = None) -> None: + """Add a new score entry with associated metadata and optional conversation + + Args: + score: The score achieved (typically 0.0-1.0) + metadata: Dictionary of metadata about the task/attempt + conversation: Optional list of conversation turns as dicts + """ + self.scores.append(score) + self.metadata.append(metadata) + self.conversations.append(conversation) + + def clear(self) -> None: + """Clear all stored scores, metadata and conversations""" + self.scores.clear() + self.metadata.clear() + self.conversations.clear() + + def __len__(self) -> int: + """Return the number of stored scores""" + return len(self.scores) + + def _metadata_to_key(self, metadata: Dict[str, Any]) -> Tuple[Tuple[str, Any], ...]: + """Convert metadata dict to tuple of key-value pairs, sorted by key + + If source_dataset and source_index are present in metadata, they will be + placed first in the tuple as ("source", dataset) and ("idx", index). + """ + # Start with empty list + key_items = [] + + # Add source info first if present + if "source_dataset" in metadata and "source_index" in metadata: + key_items.extend([("source", metadata["source_dataset"]), ("idx", metadata["source_index"])]) + + # Add difficulty parameters or other metadata + if "difficulty" in metadata: + # Use only difficulty parameters + items = metadata["difficulty"].items() + else: + # Use all metadata except source info + items = ((k, v) for k, v in metadata.items() if k not in ("source_dataset", "source_index")) + + # Add remaining items in sorted order + key_items.extend(sorted((str(k), v) for k, v in items)) + + return tuple(key_items) + + def aggregate(self, last_n: Optional[int] = None) -> GroupedScores: + """Aggregate scores by difficulty parameters or full metadata if no difficulty present + + Args: + last_n: Optional number of most recent entries to consider + If None, use all entries + + Returns: + OrderedDict mapping difficulty parameter combinations to lists of scores + Keys are tuples of (param_name, value) pairs, sorted by param_name + """ + if not self.scores: + return GroupedScores(scores=OrderedDict(), total_scores=0) + + # Determine start index for iteration + start_idx = max(0, len(self.scores) - last_n) if last_n is not None else 0 + + # Group scores by difficulty parameters without creating intermediate lists + result = OrderedDict() + for i in range(start_idx, len(self.scores)): + key = self._metadata_to_key(self.metadata[i]) + if key not in result: + result[key] = [] + result[key].append(self.scores[i]) + + # Count total scores + total_scores = sum(len(scores) for scores in result.values()) + + return GroupedScores(scores=result, total_scores=total_scores) + + +class Coach(ProceduralDataset): + """A dataset wrapper that tracks performance and adjusts difficulty + + The Coach wraps a ProceduralDataset (typically a CompositeDataset) and: + 1. Tracks scores and metadata in a ScoreBoard + 2. Adjusts difficulty based on performance (to be implemented) + """ + + def __init__(self, dataset: ProceduralDataset, score_log: Optional[Union[str, Path]] = None): + """Initialize with inner dataset + + Args: + dataset: The ProceduralDataset to wrap + score_log: Optional path to jsonl file for logging scores + """ + super().__init__(config=dataset.config, seed=dataset.seed, size=dataset.size) + self.dataset = dataset + self.score_board = ScoreBoard() + self.score_log = Path(score_log) if score_log else None + + def __getitem__(self, idx: int) -> dict: + """Forward item generation to inner dataset""" + return self.dataset[idx] + + def score_answer( + self, answer: Optional[str], entry: Dict[str, Any], conversation: Optional[List[Dict]] = None + ) -> float: + """Score answer and track results + + Args: + answer: The answer to score + entry: The task entry containing question/answer/metadata + conversation: Optional conversation history as list of message dicts + + Returns: + float: Score between 0.0 and 1.0 + """ + # Get score from inner dataset + score = self.dataset.score_answer(answer, entry) + + # Track score and metadata + self.score_board.add_score(score=score, metadata=entry["metadata"], conversation=conversation) + + # Log score if logging is enabled + if self.score_log is not None: + log_entry = {"score": score, "answer": answer, "entry": entry, "conversation": conversation} + with self.score_log.open("a") as f: + json.dump(log_entry, f) + f.write("\n") + + return score + + def update_difficulty(self) -> None: + """Update difficulty based on recent performance + + To be implemented in future versions. + """ + pass # Placeholder for future difficulty adjustment logic diff --git a/reasoning_gym/code/__init__.py b/reasoning_gym/code/__init__.py index d250ad6e..73e7fd60 100644 --- a/reasoning_gym/code/__init__.py +++ b/reasoning_gym/code/__init__.py @@ -1,8 +1,5 @@ """ -Cognition tasks for training reasoning capabilities: -- Code Analysis -- Code Interpretation -- Code Execution +Code reasing tasks """ from .bf import BFConfig, BFDataset diff --git a/reasoning_gym/cognition/__init__.py b/reasoning_gym/cognition/__init__.py index 38baf31b..751c9ab8 100644 --- a/reasoning_gym/cognition/__init__.py +++ b/reasoning_gym/cognition/__init__.py @@ -1,9 +1,5 @@ """ -Cognition tasks for training reasoning capabilities: -- Pattern recognition -- Sequence completion -- Logical reasoning -- Working memory +Cognition tasks for training reasoning capabilities. """ from .arc_1d import Arc1DConfig, Arc1DDataset diff --git a/reasoning_gym/composite.py b/reasoning_gym/composite.py index 7ea581b2..2050ddd1 100644 --- a/reasoning_gym/composite.py +++ b/reasoning_gym/composite.py @@ -100,6 +100,34 @@ class CompositeDataset(ProceduralDataset): return item + def update_dataset_config(self, dataset_name: str, config_updates: Dict[str, Any]) -> None: + """Update configuration of a specific dataset + + Args: + dataset_name: Name of the dataset to update + config_updates: Dictionary of configuration parameters to update + + Raises: + KeyError: If dataset_name is not found + AttributeError: If config parameter doesn't exist + """ + if dataset_name not in self.datasets: + raise KeyError(f"Dataset '{dataset_name}' not found") + + dataset = self.datasets[dataset_name] + + # Create new config with updates + new_config = dataset.config.__class__(**vars(dataset.config)) + for key, value in config_updates.items(): + setattr(new_config, key, value) + + # Validate new config + new_config.validate() + + # Create new dataset instance with updated config + dataset_cls = dataset.__class__ + self.datasets[dataset_name] = dataset_cls(new_config) + def score_answer(self, answer: Optional[str], entry: Dict[str, Any]) -> float: """Forward scoring to appropriate dataset""" dataset_name = entry["metadata"]["source_dataset"] diff --git a/reasoning_gym/data/__init__.py b/reasoning_gym/data/__init__.py index d0c4f943..a88fda73 100644 --- a/reasoning_gym/data/__init__.py +++ b/reasoning_gym/data/__init__.py @@ -2,7 +2,6 @@ from importlib import resources from pathlib import Path -from typing import Union def get_data_file_path(filename: str) -> Path: diff --git a/reasoning_gym/logic/__init__.py b/reasoning_gym/logic/__init__.py index 96b24277..dfa1c7ad 100644 --- a/reasoning_gym/logic/__init__.py +++ b/reasoning_gym/logic/__init__.py @@ -1,9 +1,5 @@ """ -Logic tasks for training reasoning capabilities: -- Propositional logic -- Predicate logic -- Set theory -- Syllogisms +Logic tasks for training reasoning capabilities. """ from .aiw import AliceInWonderlandConfig, AliceInWonderlandDataset diff --git a/reasoning_gym/logic/aiw.py b/reasoning_gym/logic/aiw.py index 0c864cc4..7130a11d 100644 --- a/reasoning_gym/logic/aiw.py +++ b/reasoning_gym/logic/aiw.py @@ -1,5 +1,5 @@ from dataclasses import dataclass, field -from enum import Enum +from enum import StrEnum from random import Random from string import Template from typing import List, Optional @@ -7,7 +7,7 @@ from typing import List, Optional from ..factory import ProceduralDataset, register_dataset -class TaskType(Enum): +class TaskType(StrEnum): """Defines the type of task for the Alice in Wonderland dataset.""" SIBLINGS = "siblings" diff --git a/tests/test_coaching.py b/tests/test_coaching.py new file mode 100644 index 00000000..83b56768 --- /dev/null +++ b/tests/test_coaching.py @@ -0,0 +1,240 @@ +import json +import math +from collections import OrderedDict +from pathlib import Path + +import pytest + +from reasoning_gym.arithmetic.chain_sum import ChainSum, ChainSumConfig +from reasoning_gym.arithmetic.leg_counting import LegCountingConfig +from reasoning_gym.coaching import Coach, GroupedScores +from reasoning_gym.composite import CompositeConfig, CompositeDataset, DatasetSpec + + +def test_coach_with_chain_sum(): + # Create a small ChainSum dataset + config = ChainSumConfig(min_terms=2, max_terms=3, min_digits=1, max_digits=2, size=10, seed=42) + dataset = ChainSum(config) + coach = Coach(dataset) + + # Simulate an agent working on tasks + for i in range(5): + item = coach[i] + + # Simulate some correct and incorrect answers + if i % 2 == 0: + # Correct answer + score = coach.score_answer( + answer=item["answer"], + entry=item, + conversation=[ + {"role": "user", "content": item["question"]}, + {"role": "assistant", "content": item["answer"]}, + ], + ) + assert score == 1.0 + else: + # Incorrect answer (None) + score = coach.score_answer( + answer=None, + entry=item, + conversation=[ + {"role": "user", "content": item["question"]}, + {"role": "assistant", "content": "I don't know"}, + ], + ) + assert score == 0.0 + + # Test score aggregation + aggregated = coach.score_board.aggregate() + + # Verify we have scores grouped by difficulty parameters + assert len(aggregated.scores) > 0 + + # Each key should be a tuple of tuples containing difficulty parameters + for key in aggregated.scores: + assert isinstance(key, tuple) + # Each inner tuple should be (param_name, value) + for param in key: + assert isinstance(param, tuple) + assert param[0] in ("num_terms", "num_digits") + assert isinstance(param[1], int) + + # Test aggregation with last_n + last_3 = coach.score_board.aggregate(last_n=3) + assert len(last_3.scores) > 0 + + # Verify total scores count + assert last_3.total_scores == 3 + + # Verify conversation tracking + assert len(coach.score_board.conversations) == 5 + for conv in coach.score_board.conversations: + assert len(conv) == 2 # user question and assistant response + assert conv[0]["role"] == "user" + assert conv[1]["role"] == "assistant" + + # Test stats calculation + stats = aggregated.stats() + + for key, values in stats.scores.items(): + assert isinstance(values, tuple) + assert len(values) == 5 # (count, mean, std, min, max) + assert isinstance(values[0], int) # count should be int + assert all(isinstance(v, float) for v in values[1:]) # stats should be floats + + # Test stats with empty scores + empty_stats = GroupedScores(scores=OrderedDict(), total_scores=0).stats() + assert len(empty_stats.scores) == 0 + + # Test stats with ignore_empty=False + empty_group = OrderedDict({(("test", 1),): []}) + non_ignoring_stats = GroupedScores(scores=empty_group, total_scores=0).stats(ignore_empty=False) + assert len(non_ignoring_stats.scores) == 1 + stats_tuple = next(iter(non_ignoring_stats.scores.values())) + assert stats_tuple[0] == 0 # count should be 0 for empty list + assert all(math.isnan(v) for v in stats_tuple[1:]) # stats should be NaN + + print(aggregated) + print(stats) + + # Test clear functionality + coach.score_board.clear() + assert len(coach.score_board.scores) == 0 + assert len(coach.score_board.metadata) == 0 + assert len(coach.score_board.conversations) == 0 + assert len(coach.score_board.aggregate().scores) == 0 + + +def test_coach_with_composite(): + # Create configs for both datasets + chain_sum_config = ChainSumConfig(min_terms=2, max_terms=3, min_digits=1, max_digits=2, size=10) + leg_counting_config = LegCountingConfig(min_animals=2, max_animals=3, size=10) + + # Create composite config + composite_config = CompositeConfig( + size=20, + seed=42, + datasets=[ + DatasetSpec(name="chain_sum", weight=1.0, config=chain_sum_config.__dict__), + DatasetSpec(name="leg_counting", weight=1.0, config=leg_counting_config.__dict__), + ], + ) + + # Create composite dataset and coach + dataset = CompositeDataset(composite_config) + coach = Coach(dataset) + + # Score some answers + for i in range(5): + item = coach[i] + # Correct answers for even indices + score = coach.score_answer( + answer=item["answer"] if i % 2 == 0 else None, + entry=item, + conversation=[ + {"role": "user", "content": item["question"]}, + {"role": "assistant", "content": item["answer"] if i % 2 == 0 else "I don't know"}, + ], + ) + assert score in (0.0, 1.0) + + # Test aggregation + aggregated = coach.score_board.aggregate() + assert len(aggregated.scores) > 0 + + # Verify source dataset info is first in keys + for key in aggregated.scores: + assert key[0][0] == "source" # First tuple should be ("source", dataset_name) + assert key[1][0] == "idx" # Second tuple should be ("idx", index) + + # Test stats + stats = aggregated.stats() + for key, values in stats.scores.items(): + assert isinstance(values, tuple) + assert len(values) == 5 # (count, mean, std, min, max) + assert isinstance(values[0], int) + assert all(isinstance(v, float) for v in values[1:]) + + print("\nComposite Dataset Stats:") + print(stats) + + # Test config update + coach.dataset.update_dataset_config("chain_sum", {"min_terms": 4, "max_terms": 5}) + + # Verify the config was updated + chain_sum_dataset = coach.dataset.datasets["chain_sum"] + assert chain_sum_dataset.config.min_terms == 4 + assert chain_sum_dataset.config.max_terms == 5 + + # Score some more items to verify new config is in effect + for i in range(3): + item = coach[i + 5] # Use different indices + if "chain_sum" in item["metadata"]["source_dataset"]: + metadata = item["metadata"] + assert metadata["difficulty"]["num_terms"] >= 4 + + +def test_grouped_scores_str(): + # Test raw scores string representation + scores = OrderedDict() + scores[(("num_terms", 2), ("num_digits", 1))] = [1.0, 0.0, 1.0] + scores[(("num_terms", 3), ("num_digits", 2))] = [0.5, 0.5] + grouped = GroupedScores(scores=scores, total_scores=5) + + report = str(grouped) + assert "Total scores: 5" in report + assert "(num_terms=2, num_digits=1): n=3" in report + assert "(num_terms=3, num_digits=2): n=2" in report + assert "Values: 1.00, 0.00, 1.00" in report + assert "Values: 0.50, 0.50" in report + + # Test stats string representation + stats = grouped.stats() + stats_report = str(stats) + assert "μ=" in stats_report + assert "σ=" in stats_report + assert "min=" in stats_report + assert "max=" in stats_report + + # Test empty scores + empty = GroupedScores(scores=OrderedDict(), total_scores=0) + assert str(empty) == "No scores recorded" + + +def test_coach_score_logging(tmp_path): + # Create a log file in the temporary directory + log_file = tmp_path / "scores.jsonl" + + # Create dataset and coach with logging + config = ChainSumConfig(min_terms=2, max_terms=3, min_digits=1, max_digits=2, size=10, seed=42) + dataset = ChainSum(config) + coach = Coach(dataset, score_log=log_file) + + # Score a few answers + for i in range(3): + item = coach[i] + coach.score_answer( + answer=item["answer"] if i % 2 == 0 else None, + entry=item, + conversation=[ + {"role": "user", "content": item["question"]}, + {"role": "assistant", "content": item["answer"] if i % 2 == 0 else "I don't know"}, + ], + ) + + # Verify log file contents + assert log_file.exists() + + # Read and parse log entries + log_entries = [json.loads(line) for line in log_file.open()] + assert len(log_entries) == 3 + + # Verify log entry structure + for i, entry in enumerate(log_entries): + assert "score" in entry + assert "entry" in entry + assert "metadata" in entry["entry"] + assert "conversation" in entry + assert entry["score"] == (1.0 if i % 2 == 0 else 0.0) + assert len(entry["conversation"]) == 2