diff --git a/reasoning_gym/algebra/polynomial_equations.py b/reasoning_gym/algebra/polynomial_equations.py
index ed7e857f..a1822958 100644
--- a/reasoning_gym/algebra/polynomial_equations.py
+++ b/reasoning_gym/algebra/polynomial_equations.py
@@ -1,7 +1,8 @@
+import math
 import random
 import string
 from dataclasses import dataclass
-from typing import Optional, Tuple
+from typing import Any, Dict, List, Optional, Tuple
 
 from sympy import Eq, Symbol, expand, solve
 
@@ -26,6 +27,9 @@ class PolynomialEquationsConfig:
     )  # Allowed operators between terms, Avoid adding '*' or '/' because they will affect the degree
     seed: Optional[int] = None
     size: int = 500
+    # Reward hyperparameters read by PolynomialEquationsDataset.score_answer
+    penalty_missing_factor: float = 0.1  # deducted per oracle solution that is missing from the answer
+    penalty_extra_factor: float = 0.05  # deducted per answered value that has no oracle solution to match
 
     def validate(self) -> None:
         """Validate configuration parameters."""
@@ -146,5 +150,89 @@ class PolynomialEquationsDataset(ProceduralDataset):
 
         return polynomial_expr
 
+    def _parse_score_to_list(self, answer: Optional[str]) -> List[float]:
+        """Parse a comma-separated string of numbers into a sorted list of floats.
+
+        Each comma-separated token is stripped of surrounding whitespace and
+        converted with ``float``; tokens that cannot be converted are ignored.
+
+        Args:
+            answer: Comma-separated numeric values. May be None or empty.
+
+        Returns:
+            The parsed values in ascending order. The list is empty when the input
+            is None, empty, or contains no parsable value.
+        """
+        if not answer:  # None or empty string
+            return []
+
+        parsed_values = []
+        for token in answer.split(","):
+            try:
+                parsed_values.append(float(token.strip()))
+            except ValueError:
+                # Best effort: skip tokens that are not numbers
+                continue
+
+        return sorted(parsed_values)
+
+    def score_answer(self, answer: Optional[str], entry: Dict[str, Any]) -> float:
+        """Score a predicted answer by its distance to the oracle solutions.
+
+        Both ``answer`` and ``entry["answer"]`` are parsed as comma-separated
+        numbers. Predicted values are visited in ascending order and each one is
+        greedily paired with the closest oracle solution that is still unmatched;
+        a pair at distance ``d`` earns ``exp(-d)``. The mean reward over the
+        matched pairs is then reduced by ``penalty_missing_factor`` for every
+        oracle solution left unmatched and by ``penalty_extra_factor`` for every
+        predicted value that could not be paired with an oracle solution.
+
+        Args:
+            answer: Predicted solutions as a comma-separated string. May be None.
+            entry: Dataset entry holding the oracle solutions under key "answer".
+
+        Returns:
+            The score as a float, never below 0.0.
+        """
+        oracle_solutions = self._parse_score_to_list(entry["answer"])
+        predicted_solutions = self._parse_score_to_list(answer)
+
+        total_reward = 0.0
+        matched_solutions = 0
+        extra_solutions = 0
+
+        for predicted_solution in predicted_solutions:
+            # Greedy step: find the closest oracle solution that is still unmatched
+            matched_distance = float("inf")
+            matched_distance_index = None
+            for oracle_solution_index, oracle_solution in enumerate(oracle_solutions):
+                distance = abs(predicted_solution - oracle_solution)
+                if distance < matched_distance:
+                    matched_distance = distance
+                    matched_distance_index = oracle_solution_index
+
+            if matched_distance_index is None:
+                # Nothing could be paired (e.g. no oracle solution left): count as extra
+                extra_solutions += 1
+                continue
+
+            matched_solutions += 1
+            # Each oracle solution may be matched only once
+            oracle_solutions.pop(matched_distance_index)
+            total_reward += math.exp(-matched_distance)
+
+        # Oracle solutions that were never matched are missing from the prediction
+        missing_solutions = len(oracle_solutions)
+
+        penalty = missing_solutions * self.config.penalty_missing_factor
+        penalty += extra_solutions * self.config.penalty_extra_factor
+
+        if matched_solutions > 0:
+            # Average over the matched pairs so the reward does not exceed 1
+            total_reward /= matched_solutions
+
+        # Floor at 0.0 so penalties never produce a negative score
+        return max(0.0, total_reward - penalty)
+
 
 register_dataset("polynomial_equations", PolynomialEquationsDataset, PolynomialEquationsConfig)
diff --git a/tests/test_polynomial_equations.py b/tests/test_polynomial_equations.py
index 6e1bb0c0..e7caf654 100644
--- a/tests/test_polynomial_equations.py
+++ b/tests/test_polynomial_equations.py
@@ -1,4 +1,5 @@
 import pytest
+from pytest import approx
 from sympy import Symbol, sympify
 
 from reasoning_gym import create_dataset
@@ -115,3 +116,25 @@ def test_polynomial_solutions_evaluation():
                 f"Solution {solution} does not satisfy the polynomial {poly_str}. "
                 f"Evaluated value: {evaluated_value}"
             )
+
+
+@pytest.mark.parametrize(
+    "oracle_answer, predicted_answer, expected_reward",
+    [
+        ("4,-4.12", "4,-4.12", 1.0),  # Exact match
+        ("4,-4.12", "4.0001,-4.120001", 0.9999),  # Very close match
+        ("4,-4.12", "4.1,-4.2", 0.9139),
+        ("4,8", "4", 0.9),  # Missing an oracle solution -> missing solution penalty applies
+        ("4", "4,8", 0.95),  # extra solution -> extra solution penalty
+        ("-1,-2", "1,4", 0.06890),  # -1 matched w/ 1 and -2 matched w/ 4
+        ("", "1", 0.0),  # oracle no solution, predicted extra solution
+        ("1", "", 0.0),  # oracle has a solution, predicted no solution
+    ],
+)
+def test_polynomial_solutions_score_answer(oracle_answer, predicted_answer, expected_reward):
+    # Default config: the expected rewards assume the default penalty factors
+    cfg = PolynomialEquationsConfig(seed=999, size=3)
+    ds = PolynomialEquationsDataset(cfg)
+
+    actual_reward = ds.score_answer(predicted_answer, {"answer": oracle_answer})
+    assert actual_reward == approx(expected_reward, rel=1e-3, abs=1e-9)  # abs covers the expected-zero cases