mirror of
https://github.com/open-thought/reasoning-gym.git
synced 2026-04-28 17:29:39 +00:00
[poly-reward] add a greedy strategy scoring function for polynomial equations
This commit is contained in:
parent
2ad0965fdc
commit
0dd4c05897
2 changed files with 135 additions and 3 deletions
|
|
@ -1,7 +1,8 @@
|
|||
import random
|
||||
import string
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Tuple
|
||||
from typing import Optional, Tuple, List, Dict
|
||||
|
||||
from sympy import Eq, Symbol, expand, solve
|
||||
|
||||
|
|
@ -26,6 +27,9 @@ class PolynomialEquationsConfig:
|
|||
) # Allowed operators between terms, Avoid adding '*' or '/' because they will affect the degree
|
||||
seed: Optional[int] = None
|
||||
size: int = 500
|
||||
# reward function hyperparameters
|
||||
penalty_missing_factor = 0.1
|
||||
penalty_extra_factor = 0.05
|
||||
|
||||
def validate(self) -> None:
|
||||
"""Validate configuration parameters."""
|
||||
|
|
@ -40,7 +44,8 @@ class PolynomialEquationsConfig:
|
|||
|
||||
allowed_ops = {"+", "-"}
|
||||
assert len(self.operators) > 0, "operators tuple cannot be empty."
|
||||
assert all(op in allowed_ops for op in self.operators), "Invalid operator found. Must be a subset of {+, -}."
|
||||
assert all(
|
||||
op in allowed_ops for op in self.operators), "Invalid operator found. Must be a subset of {+, -}."
|
||||
|
||||
|
||||
class PolynomialEquationsDataset(ProceduralDataset):
|
||||
|
|
@ -146,5 +151,107 @@ class PolynomialEquationsDataset(ProceduralDataset):
|
|||
|
||||
return polynomial_expr
|
||||
|
||||
def _parse_score_to_list(self, answer: Optional[str]) -> List[float]:
    """Parse a comma-separated string of numbers into a sorted list of floats.

    Each comma-separated token is stripped and converted with ``float``.
    Tokens that fail to convert are skipped, and so are non-finite results
    (``nan``, ``inf``, ``-inf``): ``float`` accepts those spellings, but a
    NaN makes the sort order undefined and none of them can be compared
    meaningfully by distance.

    Args:
        answer: Comma-separated numeric values. May be None or empty.

    Returns:
        The finite values in ascending order. An empty list when the input
        is None, empty, or contains no usable numeric token.
    """
    if answer is None or len(answer) == 0:
        return []

    parsed_values = []
    for token in answer.split(","):
        try:
            value = float(token.strip())
        except ValueError:
            # Not a number at all; skip this token and keep going.
            continue
        # Drop nan/inf so the result is always totally ordered.
        if math.isfinite(value):
            parsed_values.append(value)

    parsed_values.sort()
    return parsed_values
|
||||
|
||||
# NOTE(review): ``any`` in the annotation below is the builtin function, not
# ``typing.Any``; it is kept unchanged so this block needs no new import.
def score_answer(self, answer: Optional[str], entry: Dict[str, any]) -> float:
    """Score predicted solutions by their distance to the oracle solutions.

    Both ``answer`` and ``entry["answer"]`` are parsed with
    ``self._parse_score_to_list``. Each predicted value, in the order the
    parser returns them, is greedily paired with the closest oracle value
    that is still unmatched, and contributes ``exp(-distance)`` to the
    reward. The summed reward is averaged over the number of matched pairs.
    A penalty is then subtracted:
    ``missing * config.penalty_missing_factor`` for oracle values left
    unmatched, plus ``extra * config.penalty_extra_factor`` for predicted
    values that could not be paired.

    Args:
        answer: The predicted answer string, or None.
        entry: Dataset entry; the oracle solutions are read from
            ``entry["answer"]`` (presumably a comma-separated string, as
            that is what the parser splits on).

    Returns:
        A float in ``[0.0, 1.0]``. Returns 0.0 when ``answer`` is None and
        1.0 when neither the oracle nor the prediction contains a solution.
    """
    if answer is None:
        # No answer given: nothing can match, so the score is zero.
        return 0.0

    # Copy so that popping matched values never mutates a list owned by
    # the parser or its caller.
    unmatched_oracle = list(self._parse_score_to_list(entry["answer"]))
    predicted_solutions = self._parse_score_to_list(answer)

    if not unmatched_oracle and not predicted_solutions:
        # The equation has no solutions and none were predicted: correct.
        return 1.0

    total_reward = 0.0
    matched_solutions = 0
    extra_solutions = 0

    for predicted_solution in predicted_solutions:
        # Greedy step: find the closest oracle value still available.
        # The strict comparison keeps the first of equally distant values.
        best_distance = float("inf")
        best_index = None
        for oracle_index, oracle_solution in enumerate(unmatched_oracle):
            distance = abs(predicted_solution - oracle_solution)
            if best_distance > distance:
                best_distance = distance
                best_index = oracle_index

        if best_index is None:
            # Nothing left to pair with: this prediction is surplus.
            extra_solutions += 1
            continue

        matched_solutions += 1
        # Each oracle value may be matched at most once.
        unmatched_oracle.pop(best_index)
        total_reward += math.exp(-best_distance)

    # Whatever is still in the oracle list was never predicted.
    missing_solutions = len(unmatched_oracle)

    penalty = missing_solutions * self.config.penalty_missing_factor
    penalty += extra_solutions * self.config.penalty_extra_factor

    if matched_solutions > 0:
        # Average over matched pairs so the reward stays within [0, 1].
        total_reward = total_reward / matched_solutions

    # Clamp at zero; the 0.0 literal keeps the return type a float.
    return max(0.0, total_reward - penalty)
|
||||
|
||||
|
||||
# Register the dataset class and its config under the "polynomial_equations" name.
register_dataset("polynomial_equations", PolynomialEquationsDataset, PolynomialEquationsConfig)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue