Feat/curr adj (#394)

This commit is contained in:
joesharratt1229 2025-04-02 06:39:14 +01:00 committed by GitHub
parent 2c52f33c3a
commit 43c739cb3e
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
26 changed files with 152390 additions and 453 deletions

View file

@ -1,6 +1,6 @@
"""Experiment class combining dataset, scoreboard and curriculum."""
from typing import Any, Optional
from typing import Any, Literal, Optional
from reasoning_gym.coaching.base_curriculum import CurriculumContext
@ -27,7 +27,8 @@ class Experiment:
entry = dataset[index]
score = dataset.score_answer(answer, entry)
metadata = entry["metadata"]
self.score_board.add_score(score, metadata, conversation)
score_board_metadata = {"difficulty": metadata["difficulty"], "source_dataset": metadata["source_dataset"]}
self.score_board.add_score(dataset_name, score, score_board_metadata, conversation)
return score
@classmethod
@ -97,7 +98,15 @@ class CurriculumExperiment(Experiment):
self.curriculum_config = config
self.context = context
def update_difficulty(self):
def update_difficulty(self, dataset_name: str, method: Literal["increment", "decrement"]):
"""Update difficulty levels based on performance metrics"""
# TODO: Implement difficulty adjustment logic
pass
if method not in ["increment", "decrement"]:
raise ValueError(f"Invalid method: {method}")
if method == "increment":
self.curricula[dataset_name].increment_global_level()
elif method == "decrement":
self.curricula[dataset_name].decrement_global_level()
config = self.curricula[dataset_name].get_global_level()
self.composite.update_dataset_config(dataset_name, config)