[poly-reward] add a greedy strategy scoring function for polynomial equations

This commit is contained in:
rishabhranawat 2025-02-08 21:36:21 -08:00
parent 2ad0965fdc
commit 0dd4c05897
2 changed files with 135 additions and 3 deletions

View file

@ -1,7 +1,8 @@
import random
import string
import math
from dataclasses import dataclass
from typing import Optional, Tuple
from typing import Optional, Tuple, List, Dict
from sympy import Eq, Symbol, expand, solve
@ -26,6 +27,9 @@ class PolynomialEquationsConfig:
) # Allowed operators between terms, Avoid adding '*' or '/' because they will affect the degree
seed: Optional[int] = None
size: int = 500
# reward function hyperparameters
penalty_missing_factor = 0.1
penalty_extra_factor = 0.05
def validate(self) -> None:
"""Validate configuration parameters."""
@ -40,7 +44,8 @@ class PolynomialEquationsConfig:
allowed_ops = {"+", "-"}
assert len(self.operators) > 0, "operators tuple cannot be empty."
assert all(op in allowed_ops for op in self.operators), "Invalid operator found. Must be a subset of {+, -}."
assert all(
op in allowed_ops for op in self.operators), "Invalid operator found. Must be a subset of {+, -}."
class PolynomialEquationsDataset(ProceduralDataset):
@ -146,5 +151,107 @@ class PolynomialEquationsDataset(ProceduralDataset):
return polynomial_expr
def _parse_score_to_list(self, answer: Optional[str]) -> List[float]:
"""Parses a comma-separated string of scores into a sorted list of floats.
register_dataset("polynomial_equations", PolynomialEquationsDataset, PolynomialEquationsConfig)
This method takes a string containing comma-separated numeric values, attempts to convert each value to a float,
and returns a sorted list of these floats. Any values that cannot be converted to a float are ignored.
Handles empty strings gracefully.
Args:
answer: An optional string containing comma-separated numeric values. Can be None or an empty string.
Returns:
A sorted list of floats parsed from the input string. Returns an empty list if the input is None, empty, or contains no valid numeric values.
"""
if answer is None or len(answer) == 0: # Handle None or empty input
return []
output_float_vals = []
for output_val in answer.split(","):
try:
# Convert to float, strip whitespace
output_float_vals.append(float(output_val.strip()))
except ValueError:
# Ignore values that cannot be converted to float
continue # Continue to the next value in the string
return sorted(output_float_vals) # Return the sorted list of floats
def score_answer(self, answer: Optional[str], entry: Dict[str, any]) -> float:
"""
Score an answer based on its numerical distance to oracle solutions using exponential decay.
This function compares a predicted answer (or list of answers) to a set of oracle solutions
(also a list of numbers). It calculates a reward based on how close the predicted solutions
are to the oracle solutions, using an exponential decay function. It also applies penalties
for missing or extra predicted solutions.
Args:
answer: The predicted answer (or a string that can be parsed into a list of numbers).
May be None.
entry: A dictionary containing the oracle solution(s) under the key "answer"
(which can be a string that can be parsed into a list of numbers).
Returns:
A float representing the final score. The score is non-negative.
Raises:
TypeError: If the 'answer' in entry is not a string or list.
ValueError: If the 'answer' in entry cannot be parsed to a number.
TypeError: If the answer is not a string or list.
ValueError: If the answer cannot be parsed to a number.
"""
oracle_solutions = self._parse_score_to_list(
entry["answer"]) # Parse oracle solutions
predicted_solutions = self._parse_score_to_list(
answer) # Parse predicted solutions
total_reward = 0.0
matched_solutions = 0
extra_solutions = 0
missing_solutions = 0
for predicted_solution in predicted_solutions:
# find the closest matching solution from the oracle solutions.
# this is a greedy approach to computing the score
matched_distance = float('inf')
matched_distance_index = None
for oracle_solution_index, oracle_solution in enumerate(oracle_solutions):
if matched_distance > abs(predicted_solution - oracle_solution):
matched_distance = abs(
predicted_solution - oracle_solution)
matched_distance_index = oracle_solution_index
if matched_distance_index is not None:
matched_solutions += 1
# Remove matched oracle solution
oracle_solutions.pop(matched_distance_index)
# Exponential decay reward
total_reward += math.exp(-matched_distance)
else:
# Extra predicted solution
extra_solutions += 1
# Count remaining oracle solutions as missing
for oracle_solution in oracle_solutions:
missing_solutions += 1
# Calculate penalty for either missing or extra solutions
penalty = missing_solutions * self.config.penalty_missing_factor
penalty += extra_solutions * self.config.penalty_extra_factor
if matched_solutions > 0:
# normalize the rewards that we found matching solutions for
# so that the value is bounded between 0 and 1
total_reward = total_reward / matched_solutions
# Final reward capped at 0
final_reward = max(0, total_reward - penalty)
return final_reward
register_dataset("polynomial_equations",
PolynomialEquationsDataset, PolynomialEquationsConfig)