updated datasets

This commit is contained in:
joesharratt1229 2025-04-01 16:11:31 +00:00
parent 9f9f816902
commit 37bbd97191
3 changed files with 28 additions and 23 deletions

View file

@ -5,6 +5,8 @@ from dataclasses import dataclass
from random import Random from random import Random
from typing import Any, Optional from typing import Any, Optional
import numpy as np
from ..coaching import BaseCurriculum, RangeAttributeDefinition from ..coaching import BaseCurriculum, RangeAttributeDefinition
from ..factory import ProceduralDataset, register_dataset from ..factory import ProceduralDataset, register_dataset
@ -44,12 +46,6 @@ Please follow the instruction below:
## 2. Convert all numbers in the square brackets as strings. For example, ['-69', '-13', '1', '7', '11', '43', '59', '61'] ## 2. Convert all numbers in the square brackets as strings. For example, ['-69', '-13', '1', '7', '11', '43', '59', '61']
""" """
def _format_number(self, num: float, decimals: int) -> str:
"""Format number with specified decimal places"""
formatted = f"{num:.{decimals}f}"
# Reparse to ensure exact decimal representation
return f"{float(formatted):.{decimals}f}"
def _generate_numbers(self, rng: Random, count: int) -> tuple[list[float], list[str]]: def _generate_numbers(self, rng: Random, count: int) -> tuple[list[float], list[str]]:
"""Generate list of numbers and their string representations""" """Generate list of numbers and their string representations"""
numbers = [] numbers = []
@ -58,11 +54,9 @@ Please follow the instruction below:
for _ in range(count): for _ in range(count):
num = rng.uniform(self.config.min_value, self.config.max_value) num = rng.uniform(self.config.min_value, self.config.max_value)
decimals = rng.randint(self.config.min_decimals, self.config.max_decimals) decimals = rng.randint(self.config.min_decimals, self.config.max_decimals)
num_str = self._format_number(num, decimals) num = np.round(num, decimals)
# Reparse to ensure exact value
num = float(num_str)
numbers.append(num) numbers.append(num)
number_strs.append(num_str) number_strs.append(str(num))
return numbers, number_strs return numbers, number_strs
@ -78,9 +72,8 @@ Please follow the instruction below:
desc_numbers = sorted(numbers, reverse=True) desc_numbers = sorted(numbers, reverse=True)
# Format answers as string lists # Format answers as string lists
decimals = len(number_strs[0].split(".")[-1]) if "." in number_strs[0] else 0 asc_answer = [str(n) for n in asc_numbers]
asc_answer = [self._format_number(n, decimals) for n in asc_numbers] desc_answer = [str(n) for n in desc_numbers]
desc_answer = [self._format_number(n, decimals) for n in desc_numbers]
# Randomly choose ascending or descending # Randomly choose ascending or descending
is_ascending = rng.choice([True, False]) is_ascending = rng.choice([True, False])
@ -158,7 +151,7 @@ Please follow the instruction below:
return 0.0 return 0.0
# Check if the values are close enough (allowing for small rounding differences) # Check if the values are close enough (allowing for small rounding differences)
tolerance = 0.1 # Increased tolerance to handle decimal differences tolerance = 1 # Increased tolerance to handle decimal differences
for i in range(len(user_floats)): for i in range(len(user_floats)):
if abs(user_floats[i] - expected_floats[i]) > tolerance: if abs(user_floats[i] - expected_floats[i]) > tolerance:
return 0.0 return 0.0

View file

@ -72,7 +72,7 @@ class SpellBackwardDataset(ProceduralDataset):
expected_answer = expected_answer.lower() expected_answer = expected_answer.lower()
answer = answer.lower() answer = answer.lower()
if expected_answer == answer: if expected_answer == answer:
reward = 1.0 return 1.0
else: else:
answer_len = len(expected_answer) answer_len = len(expected_answer)
for i in range(len(expected_answer)): for i in range(len(expected_answer)):
@ -83,7 +83,8 @@ class SpellBackwardDataset(ProceduralDataset):
continue continue
else: else:
break break
if reward == 1.0:
reward -= 0.2
except: except:
reward = 0.0 reward = 0.0
return reward return reward

View file

@ -125,14 +125,25 @@ class WordSortingDataset(ProceduralDataset):
def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float: def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
oracle_answer = entry["metadata"]["sorted_words"] oracle_answer = entry["metadata"]["sorted_words"]
if answer is not None and len(answer) > 0:
if not answer:
return 0.0
parsed_answer = [word.strip() for word in re.split(r",\s*", answer)] parsed_answer = [word.strip() for word in re.split(r",\s*", answer)]
if parsed_answer == oracle_answer: if parsed_answer == oracle_answer:
return 1.0 return 1.0
elif sorted(parsed_answer) == oracle_answer:
return 0.2
return 0.0 correct_positions = sum(
1 for i, word in enumerate(parsed_answer) if i < len(oracle_answer) and word == oracle_answer[i]
)
partial_score = correct_positions / len(oracle_answer)
if sorted(parsed_answer) == sorted(oracle_answer):
partial_score = max(partial_score, 0.2)
return partial_score
class WordSortingCurriculum(BaseCurriculum): class WordSortingCurriculum(BaseCurriculum):