From d9b08a579ea4421da2ebba1004426d7a46c17061 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Sun, 16 Feb 2025 09:04:17 +0000 Subject: [PATCH 1/8] reformatted word ladder question template --- reasoning_gym/algorithmic/word_ladder.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/reasoning_gym/algorithmic/word_ladder.py b/reasoning_gym/algorithmic/word_ladder.py index 64c65326..40e36291 100644 --- a/reasoning_gym/algorithmic/word_ladder.py +++ b/reasoning_gym/algorithmic/word_ladder.py @@ -9,6 +9,11 @@ from ..data import get_data_file_path from ..factory import ProceduralDataset, register_dataset +QUESTION_TEMPLATE = """Transform the word ladder '{start}' to '{end}' by changing one letter at a time. + Provide your answer as a comma-separated sequence of uppercase letters without spaces. + Each step must be a valid English word.""" + + @dataclass class WordLadderConfig: """Configuration for word ladder task generation""" @@ -211,7 +216,7 @@ class WordLadderDataset(ProceduralDataset): raise IndexError(f"Dataset exhausted at index {idx}. {str(e)}") return { - "question": f"Transform the word ladder '{start}' to '{end}' by changing one letter at a time.", + "question": QUESTION_TEMPLATE.format(start=start, end=end), "answer": ",".join(path), "metadata": {"start_word": start, "end_word": end, "word_length": length, "chain_length": len(path)}, } From c28688cb9691d63ec3936894df1ef1823a1b8817 Mon Sep 17 00:00:00 2001 From: joesharratt1229 Date: Sun, 16 Feb 2025 09:07:56 +0000 Subject: [PATCH 2/8] reformatted basic airth question template --- reasoning_gym/arithmetic/basic_arithmetic.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/reasoning_gym/arithmetic/basic_arithmetic.py b/reasoning_gym/arithmetic/basic_arithmetic.py index 156314d9..bab95b63 100644 --- a/reasoning_gym/arithmetic/basic_arithmetic.py +++ b/reasoning_gym/arithmetic/basic_arithmetic.py @@ -224,11 +224,13 @@ class BasicArithmeticDataset(ProceduralDataset): def _format_question(self, rng: Random, expression: str) -> str: """Format the expression according to config style""" + base_question = "Return only the answer to the following question: {question} =" if self.config.format_style == "simple": - return f"{expression} =" + return base_question.format(question=expression) else: templates = ["What is {0}?", "Calculate {0}", "Solve {0}", "Evaluate the expression: {0}"] - return rng.choice(templates).format(expression) + question = rng.choice(templates).format(expression) + return base_question.format(question=question) # Register the dataset From a59e4cc9188e4c1577989af25facf21b61bbc06e Mon Sep 17 00:00:00 2001 From: joesharratt1229 Date: Sun, 16 Feb 2025 09:27:21 +0000 Subject: [PATCH 3/8] reformatted prompt --- reasoning_gym/arithmetic/basic_arithmetic.py | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/reasoning_gym/arithmetic/basic_arithmetic.py b/reasoning_gym/arithmetic/basic_arithmetic.py index bab95b63..a65ea295 100644 --- a/reasoning_gym/arithmetic/basic_arithmetic.py +++ b/reasoning_gym/arithmetic/basic_arithmetic.py @@ -223,14 +223,19 @@ class BasicArithmeticDataset(ProceduralDataset): return expression, result def _format_question(self, rng: Random, expression: str) -> str: - """Format the expression according to config style""" - base_question = "Return only the answer to the following question: {question} =" + """Format the expression with clear answer positioning""" + answer_instruction = "Put your final answer after '=' without additional text." + if self.config.format_style == "simple": - return base_question.format(question=expression) + return f"Calculate {expression} =" else: - templates = ["What is {0}?", "Calculate {0}", "Solve {0}", "Evaluate the expression: {0}"] - question = rng.choice(templates).format(expression) - return base_question.format(question=question) + templates = [ + "What is {0}? =", + "Solve {0} and write answer after =", + "Compute {0} =", + "Evaluate: {0} =" + ] + return rng.choice(templates).format(expression) + f" {answer_instruction}" # Register the dataset From 569517664fb25a2e66573d2533c6aa7671749944 Mon Sep 17 00:00:00 2001 From: joesharratt1229 Date: Sun, 16 Feb 2025 12:01:54 +0000 Subject: [PATCH 4/8] corrected failing airthmetic test --- reasoning_gym/arithmetic/basic_arithmetic.py | 9 +++++---- tests/test_basic_arithmetic.py | 2 +- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/reasoning_gym/arithmetic/basic_arithmetic.py b/reasoning_gym/arithmetic/basic_arithmetic.py index a65ea295..efe9d465 100644 --- a/reasoning_gym/arithmetic/basic_arithmetic.py +++ b/reasoning_gym/arithmetic/basic_arithmetic.py @@ -227,15 +227,16 @@ class BasicArithmeticDataset(ProceduralDataset): answer_instruction = "Put your final answer after '=' without additional text." if self.config.format_style == "simple": - return f"Calculate {expression} =" + return f"{answer_instruction} Calculate {expression} =" else: templates = [ - "What is {0}? =", - "Solve {0} and write answer after =", + "What is {0} =", + "Solve {0}=", "Compute {0} =", "Evaluate: {0} =" ] - return rng.choice(templates).format(expression) + f" {answer_instruction}" + template = rng.choice(templates).format(expression) + return f"{answer_instruction} {template}" # Register the dataset diff --git a/tests/test_basic_arithmetic.py b/tests/test_basic_arithmetic.py index 3d3d08b5..406e4617 100644 --- a/tests/test_basic_arithmetic.py +++ b/tests/test_basic_arithmetic.py @@ -68,7 +68,7 @@ def test_arithmetic_dataset_format_styles(): config.format_style = "natural" dataset = BasicArithmeticDataset(config) - assert all("=" not in item["question"] for item in dataset) + assert all("=" in item["question"] for item in dataset) def test_arithmetic_dataset_iteration(): From 1c930a5e236917db0651adef15747ba06f201003 Mon Sep 17 00:00:00 2001 From: joesharratt1229 Date: Sun, 16 Feb 2025 13:10:03 +0000 Subject: [PATCH 5/8] added custom score answer func --- .../algorithmic/sentence_reordering.py | 30 ++++++++++++++++++- 1 file changed, 29 insertions(+), 1 deletion(-) diff --git a/reasoning_gym/algorithmic/sentence_reordering.py b/reasoning_gym/algorithmic/sentence_reordering.py index acb7cd23..069dda7c 100644 --- a/reasoning_gym/algorithmic/sentence_reordering.py +++ b/reasoning_gym/algorithmic/sentence_reordering.py @@ -3,7 +3,7 @@ import re from dataclasses import dataclass from random import Random -from typing import Optional +from typing import Any, Dict, Optional from ..data import read_data_file from ..factory import ProceduralDataset, register_dataset @@ -92,5 +92,33 @@ class SentenceReorderingDataset(ProceduralDataset): "metadata": {"word_count": word_count}, } + def score_answer(self, answer: Optional[str], entry: Dict[str, Any]) -> float: + reward = 0 + expected_answer = entry["answer"] + if answer is not None: + try: + if expected_answer == answer: + return 1.0 + goal_words = expected_answer.split() + answer_words = answer.split() + if len(goal_words) == len(answer_words): + credit = [1 if goal_word.lower() == answer_word.lower() else 0 for goal_word, answer_word in zip(goal_words, answer_words)] + reward = sum(credit) / len(credit) + else: + reward = 0.05 + except: + reward = 0.01 + return reward + + + + + + + + + + + register_dataset("sentence_reordering", SentenceReorderingDataset, SentenceReorderingConfig) From 35fe482c4d785012959fdb23162d5e47060e6ef1 Mon Sep 17 00:00:00 2001 From: joesharratt1229 Date: Sun, 16 Feb 2025 13:10:26 +0000 Subject: [PATCH 6/8] updated spell backward impl --- reasoning_gym/algorithmic/spell_backward.py | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/reasoning_gym/algorithmic/spell_backward.py b/reasoning_gym/algorithmic/spell_backward.py index 59b163ee..d1837521 100644 --- a/reasoning_gym/algorithmic/spell_backward.py +++ b/reasoning_gym/algorithmic/spell_backward.py @@ -3,7 +3,7 @@ import re from dataclasses import dataclass from random import Random -from typing import Optional +from typing import Any, Dict, Optional from ..data import read_data_file from ..factory import ProceduralDataset, register_dataset @@ -49,5 +49,18 @@ class SpellBackwardDataset(ProceduralDataset): "metadata": {"word": word, "word_len": len(word)}, } + def score_answer(self, answer: Optional[str], entry: Dict[str, Any]) -> float: + reward = 0 + expected_answer = entry["answer"] + if answer is not None: + try: + if expected_answer.lower() == answer.lower(): + reward = 1.0 + else: + reward = 0.05 + except: + reward = 0.01 + return reward + register_dataset("spell_backward", SpellBackwardDataset, SpellBackwardConfig) From d2d4b3a644df0b7a9e49383c76c76eb0791079ba Mon Sep 17 00:00:00 2001 From: joesharratt1229 Date: Sun, 16 Feb 2025 13:10:43 +0000 Subject: [PATCH 7/8] added another assertion to test --- tests/test_sentence_reordering.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test_sentence_reordering.py b/tests/test_sentence_reordering.py index 9ed5b4da..05645c70 100644 --- a/tests/test_sentence_reordering.py +++ b/tests/test_sentence_reordering.py @@ -37,6 +37,7 @@ def test_getitem(dataset, config): assert "metadata" in item assert item["metadata"]["word_count"] >= config.min_words_in_sentence assert item["metadata"]["word_count"] <= config.max_words_in_sentence + assert len(item['answer'].split()) == item['metadata']['word_count'] def test_key_error_in_getitem(dataset): From 6bf2dfa36ccf056afa4a3e4e495aa5f32fe2e362 Mon Sep 17 00:00:00 2001 From: Andreas Koepf Date: Sun, 16 Feb 2025 16:18:39 +0100 Subject: [PATCH 8/8] formatting --- reasoning_gym/algorithmic/sentence_reordering.py | 15 ++++----------- reasoning_gym/algorithmic/word_ladder.py | 1 - reasoning_gym/arithmetic/basic_arithmetic.py | 9 ++------- tests/test_sentence_reordering.py | 2 +- 4 files changed, 7 insertions(+), 20 deletions(-) diff --git a/reasoning_gym/algorithmic/sentence_reordering.py b/reasoning_gym/algorithmic/sentence_reordering.py index 069dda7c..57f19d6e 100644 --- a/reasoning_gym/algorithmic/sentence_reordering.py +++ b/reasoning_gym/algorithmic/sentence_reordering.py @@ -102,23 +102,16 @@ class SentenceReorderingDataset(ProceduralDataset): goal_words = expected_answer.split() answer_words = answer.split() if len(goal_words) == len(answer_words): - credit = [1 if goal_word.lower() == answer_word.lower() else 0 for goal_word, answer_word in zip(goal_words, answer_words)] + credit = [ + 1 if goal_word.lower() == answer_word.lower() else 0 + for goal_word, answer_word in zip(goal_words, answer_words) + ] reward = sum(credit) / len(credit) else: reward = 0.05 except: reward = 0.01 return reward - - - - - - - - - - register_dataset("sentence_reordering", SentenceReorderingDataset, SentenceReorderingConfig) diff --git a/reasoning_gym/algorithmic/word_ladder.py b/reasoning_gym/algorithmic/word_ladder.py index 40e36291..3be99138 100644 --- a/reasoning_gym/algorithmic/word_ladder.py +++ b/reasoning_gym/algorithmic/word_ladder.py @@ -8,7 +8,6 @@ from typing import Dict, List, Optional, Set, Tuple from ..data import get_data_file_path from ..factory import ProceduralDataset, register_dataset - QUESTION_TEMPLATE = """Transform the word ladder '{start}' to '{end}' by changing one letter at a time. Provide your answer as a comma-separated sequence of uppercase letters without spaces. Each step must be a valid English word.""" diff --git a/reasoning_gym/arithmetic/basic_arithmetic.py b/reasoning_gym/arithmetic/basic_arithmetic.py index efe9d465..c72ff143 100644 --- a/reasoning_gym/arithmetic/basic_arithmetic.py +++ b/reasoning_gym/arithmetic/basic_arithmetic.py @@ -225,16 +225,11 @@ class BasicArithmeticDataset(ProceduralDataset): def _format_question(self, rng: Random, expression: str) -> str: """Format the expression with clear answer positioning""" answer_instruction = "Put your final answer after '=' without additional text." - + if self.config.format_style == "simple": return f"{answer_instruction} Calculate {expression} =" else: - templates = [ - "What is {0} =", - "Solve {0}=", - "Compute {0} =", - "Evaluate: {0} =" - ] + templates = ["What is {0} =", "Solve {0}=", "Compute {0} =", "Evaluate: {0} ="] template = rng.choice(templates).format(expression) return f"{answer_instruction} {template}" diff --git a/tests/test_sentence_reordering.py b/tests/test_sentence_reordering.py index 05645c70..9348ec04 100644 --- a/tests/test_sentence_reordering.py +++ b/tests/test_sentence_reordering.py @@ -37,7 +37,7 @@ def test_getitem(dataset, config): assert "metadata" in item assert item["metadata"]["word_count"] >= config.min_words_in_sentence assert item["metadata"]["word_count"] <= config.max_words_in_sentence - assert len(item['answer'].split()) == item['metadata']['word_count'] + assert len(item["answer"].split()) == item["metadata"]["word_count"] def test_key_error_in_getitem(dataset):