diff --git a/reasoning_gym/algorithmic/sentence_reordering.py b/reasoning_gym/algorithmic/sentence_reordering.py index acb7cd23..57f19d6e 100644 --- a/reasoning_gym/algorithmic/sentence_reordering.py +++ b/reasoning_gym/algorithmic/sentence_reordering.py @@ -3,7 +3,7 @@ import re from dataclasses import dataclass from random import Random -from typing import Optional +from typing import Any, Dict, Optional from ..data import read_data_file from ..factory import ProceduralDataset, register_dataset @@ -92,5 +92,26 @@ class SentenceReorderingDataset(ProceduralDataset): "metadata": {"word_count": word_count}, } + def score_answer(self, answer: Optional[str], entry: Dict[str, Any]) -> float: + reward = 0 + expected_answer = entry["answer"] + if answer is not None: + try: + if expected_answer == answer: + return 1.0 + goal_words = expected_answer.split() + answer_words = answer.split() + if len(goal_words) == len(answer_words): + credit = [ + 1 if goal_word.lower() == answer_word.lower() else 0 + for goal_word, answer_word in zip(goal_words, answer_words) + ] + reward = sum(credit) / len(credit) + else: + reward = 0.05 + except: + reward = 0.01 + return reward + register_dataset("sentence_reordering", SentenceReorderingDataset, SentenceReorderingConfig) diff --git a/reasoning_gym/algorithmic/spell_backward.py b/reasoning_gym/algorithmic/spell_backward.py index 59b163ee..d1837521 100644 --- a/reasoning_gym/algorithmic/spell_backward.py +++ b/reasoning_gym/algorithmic/spell_backward.py @@ -3,7 +3,7 @@ import re from dataclasses import dataclass from random import Random -from typing import Optional +from typing import Any, Dict, Optional from ..data import read_data_file from ..factory import ProceduralDataset, register_dataset @@ -49,5 +49,18 @@ class SpellBackwardDataset(ProceduralDataset): "metadata": {"word": word, "word_len": len(word)}, } + def score_answer(self, answer: Optional[str], entry: Dict[str, Any]) -> float: + reward = 0 + expected_answer = entry["answer"] + if answer is not None: + try: + if expected_answer.lower() == answer.lower(): + reward = 1.0 + else: + reward = 0.05 + except: + reward = 0.01 + return reward + register_dataset("spell_backward", SpellBackwardDataset, SpellBackwardConfig) diff --git a/reasoning_gym/algorithmic/word_ladder.py b/reasoning_gym/algorithmic/word_ladder.py index 64c65326..3be99138 100644 --- a/reasoning_gym/algorithmic/word_ladder.py +++ b/reasoning_gym/algorithmic/word_ladder.py @@ -8,6 +8,10 @@ from typing import Dict, List, Optional, Set, Tuple from ..data import get_data_file_path from ..factory import ProceduralDataset, register_dataset +QUESTION_TEMPLATE = """Transform the word ladder '{start}' to '{end}' by changing one letter at a time. + Provide your answer as a comma-separated sequence of uppercase letters without spaces. + Each step must be a valid English word.""" + @dataclass class WordLadderConfig: @@ -211,7 +215,7 @@ class WordLadderDataset(ProceduralDataset): raise IndexError(f"Dataset exhausted at index {idx}. {str(e)}") return { - "question": f"Transform the word ladder '{start}' to '{end}' by changing one letter at a time.", + "question": QUESTION_TEMPLATE.format(start=start, end=end), "answer": ",".join(path), "metadata": {"start_word": start, "end_word": end, "word_length": length, "chain_length": len(path)}, } diff --git a/reasoning_gym/arithmetic/basic_arithmetic.py b/reasoning_gym/arithmetic/basic_arithmetic.py index 156314d9..c72ff143 100644 --- a/reasoning_gym/arithmetic/basic_arithmetic.py +++ b/reasoning_gym/arithmetic/basic_arithmetic.py @@ -223,12 +223,15 @@ class BasicArithmeticDataset(ProceduralDataset): return expression, result def _format_question(self, rng: Random, expression: str) -> str: - """Format the expression according to config style""" + """Format the expression with clear answer positioning""" + answer_instruction = "Put your final answer after '=' without additional text." + if self.config.format_style == "simple": - return f"{expression} =" + return f"{answer_instruction} Calculate {expression} =" else: - templates = ["What is {0}?", "Calculate {0}", "Solve {0}", "Evaluate the expression: {0}"] - return rng.choice(templates).format(expression) + templates = ["What is {0} =", "Solve {0}=", "Compute {0} =", "Evaluate: {0} ="] + template = rng.choice(templates).format(expression) + return f"{answer_instruction} {template}" # Register the dataset diff --git a/tests/test_basic_arithmetic.py b/tests/test_basic_arithmetic.py index 3d3d08b5..406e4617 100644 --- a/tests/test_basic_arithmetic.py +++ b/tests/test_basic_arithmetic.py @@ -68,7 +68,7 @@ def test_arithmetic_dataset_format_styles(): config.format_style = "natural" dataset = BasicArithmeticDataset(config) - assert all("=" not in item["question"] for item in dataset) + assert all("=" in item["question"] for item in dataset) def test_arithmetic_dataset_iteration(): diff --git a/tests/test_sentence_reordering.py b/tests/test_sentence_reordering.py index 9ed5b4da..9348ec04 100644 --- a/tests/test_sentence_reordering.py +++ b/tests/test_sentence_reordering.py @@ -37,6 +37,7 @@ def test_getitem(dataset, config): assert "metadata" in item assert item["metadata"]["word_count"] >= config.min_words_in_sentence assert item["metadata"]["word_count"] <= config.max_words_in_sentence + assert len(item["answer"].split()) == item["metadata"]["word_count"] def test_key_error_in_getitem(dataset):