diff --git a/reasoning_gym/algorithmic/number_sorting.py b/reasoning_gym/algorithmic/number_sorting.py index cf572668..d6e46437 100644 --- a/reasoning_gym/algorithmic/number_sorting.py +++ b/reasoning_gym/algorithmic/number_sorting.py @@ -5,6 +5,8 @@ from dataclasses import dataclass from random import Random from typing import Any, Optional +import numpy as np + from ..coaching import BaseCurriculum, RangeAttributeDefinition from ..factory import ProceduralDataset, register_dataset @@ -44,12 +46,6 @@ Please follow the instruction below: ## 2. Convert all numbers in the square brackets as strings. For example, ['-69', '-13', '1', '7', '11', '43', '59', '61'] """ - def _format_number(self, num: float, decimals: int) -> str: - """Format number with specified decimal places""" - formatted = f"{num:.{decimals}f}" - # Reparse to ensure exact decimal representation - return f"{float(formatted):.{decimals}f}" - def _generate_numbers(self, rng: Random, count: int) -> tuple[list[float], list[str]]: """Generate list of numbers and their string representations""" numbers = [] @@ -58,11 +54,9 @@ Please follow the instruction below: for _ in range(count): num = rng.uniform(self.config.min_value, self.config.max_value) decimals = rng.randint(self.config.min_decimals, self.config.max_decimals) - num_str = self._format_number(num, decimals) - # Reparse to ensure exact value - num = float(num_str) + num = np.round(num, decimals) numbers.append(num) - number_strs.append(num_str) + number_strs.append(str(num)) return numbers, number_strs @@ -78,9 +72,8 @@ Please follow the instruction below: desc_numbers = sorted(numbers, reverse=True) # Format answers as string lists - decimals = len(number_strs[0].split(".")[-1]) if "." in number_strs[0] else 0 - asc_answer = [self._format_number(n, decimals) for n in asc_numbers] - desc_answer = [self._format_number(n, decimals) for n in desc_numbers] + asc_answer = [str(n) for n in asc_numbers] + desc_answer = [str(n) for n in desc_numbers] # Randomly choose ascending or descending is_ascending = rng.choice([True, False]) @@ -158,7 +151,7 @@ Please follow the instruction below: return 0.0 # Check if the values are close enough (allowing for small rounding differences) - tolerance = 0.1 # Increased tolerance to handle decimal differences + tolerance = 1 # Increased tolerance to handle decimal differences for i in range(len(user_floats)): if abs(user_floats[i] - expected_floats[i]) > tolerance: return 0.0 diff --git a/reasoning_gym/algorithmic/spell_backward.py b/reasoning_gym/algorithmic/spell_backward.py index a73acce6..0de8d5f2 100644 --- a/reasoning_gym/algorithmic/spell_backward.py +++ b/reasoning_gym/algorithmic/spell_backward.py @@ -72,7 +72,7 @@ class SpellBackwardDataset(ProceduralDataset): expected_answer = expected_answer.lower() answer = answer.lower() if expected_answer == answer: - reward = 1.0 + return 1.0 else: answer_len = len(expected_answer) for i in range(len(expected_answer)): @@ -83,7 +83,8 @@ class SpellBackwardDataset(ProceduralDataset): continue else: break - + if reward == 1.0: + reward -= 0.2 except: reward = 0.0 return reward diff --git a/reasoning_gym/algorithmic/word_sorting.py b/reasoning_gym/algorithmic/word_sorting.py index 1fc20e28..0e72d20b 100644 --- a/reasoning_gym/algorithmic/word_sorting.py +++ b/reasoning_gym/algorithmic/word_sorting.py @@ -125,14 +125,25 @@ class WordSortingDataset(ProceduralDataset): def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float: oracle_answer = entry["metadata"]["sorted_words"] - if answer is not None and len(answer) > 0: - parsed_answer = [word.strip() for word in re.split(r",\s*", answer)] - if parsed_answer == oracle_answer: - return 1.0 - elif sorted(parsed_answer) == oracle_answer: - return 0.2 - return 0.0 + if not answer: + return 0.0 + + parsed_answer = [word.strip() for word in re.split(r",\s*", answer)] + + if parsed_answer == oracle_answer: + return 1.0 + + correct_positions = sum( + 1 for i, word in enumerate(parsed_answer) if i < len(oracle_answer) and word == oracle_answer[i] + ) + + partial_score = correct_positions / len(oracle_answer) + + if sorted(parsed_answer) == sorted(oracle_answer): + partial_score = max(partial_score, 0.2) + + return partial_score class WordSortingCurriculum(BaseCurriculum):