diff --git a/training/train_grpo.py b/training/train_grpo.py index 9da5d16d..48a6a4d0 100644 --- a/training/train_grpo.py +++ b/training/train_grpo.py @@ -41,8 +41,8 @@ def prepare_datasets(config, tokenizer) -> tuple[ReasoningGymDataset, ReasoningG ] train_data_source = reasoning_gym.create_dataset("composite", seed=1, size=dataset_size, datasets=dataset_specs) val_data_source = reasoning_gym.create_dataset("composite", seed=2, size=dataset_size, datasets=dataset_specs) - train_dataset = make_dataset(tokenizer, train_data_source, "composite", developer_prompt) - val_dataset = make_dataset(tokenizer, val_data_source, "composite", developer_prompt) + train_dataset = make_dataset(tokenizer, train_data_source, developer_prompt) + val_dataset = make_dataset(tokenizer, val_data_source, developer_prompt) return train_dataset, val_dataset diff --git a/training/utils/datasets.py b/training/utils/datasets.py index 973ae2f0..41aeb591 100644 --- a/training/utils/datasets.py +++ b/training/utils/datasets.py @@ -105,7 +105,6 @@ class ReasoningGymDataset(Dataset): def make_dataset( tokenizer, data_source: Experiment | ProceduralDataset, - dataset_name: str, developer_prompt: str, ) -> ReasoningGymDataset: """