diff --git a/reasoning_gym/algebra/polynomial_equations.py b/reasoning_gym/algebra/polynomial_equations.py index 2a20fdf4..96498148 100644 --- a/reasoning_gym/algebra/polynomial_equations.py +++ b/reasoning_gym/algebra/polynomial_equations.py @@ -154,15 +154,17 @@ class PolynomialEquationsDataset(ProceduralDataset): def _parse_score_to_list(self, answer: Optional[str]) -> List[float]: """Parses a comma-separated string of scores into a sorted list of floats. - This method takes a string containing comma-separated numeric values, attempts to convert each value to a float, - and returns a sorted list of these floats. Any values that cannot be converted to a float are ignored. + This method takes a string containing comma-separated numeric values, + attempts to convert each value to a float, and returns a sorted list of these floats. + Any values that cannot be converted to a float are ignored. Handles empty strings gracefully. Args: - answer: An optional string containing comma-separated numeric values. Can be None or an empty string. - + answer: An optional string containing comma-separated numeric values. + Can be None or an empty string. Returns: - A sorted list of floats parsed from the input string. Returns an empty list if the input is None, empty, or contains no valid numeric values. + A sorted list of floats parsed from the input string. + Returns an empty list if the input is None, empty, or contains no valid numeric values. """ if answer is None or len(answer) == 0: # Handle None or empty input @@ -175,7 +177,7 @@ class PolynomialEquationsDataset(ProceduralDataset): output_float_vals.append(float(output_val.strip())) except ValueError: # Ignore values that cannot be converted to float - continue # Continue to the next value in the string + continue return sorted(output_float_vals) # Return the sorted list of floats @@ -185,7 +187,9 @@ class PolynomialEquationsDataset(ProceduralDataset): This function compares a predicted answer (or list of answers) to a set of oracle solutions (also a list of numbers). It calculates a reward based on how close the predicted solutions are to the oracle solutions, using an exponential decay function. It also applies penalties - for missing or extra predicted solutions. + for missing or extra predicted solutions. The implementation is a greedy algorithm where we + find the closest matching oracle solution for a given predicted solution and only allow an + oracle solution to match once. Args: answer: The predicted answer (or a string that can be parsed into a list of numbers). @@ -195,12 +199,6 @@ class PolynomialEquationsDataset(ProceduralDataset): Returns: A float representing the final score. The score is non-negative. - - Raises: - TypeError: If the 'answer' in entry is not a string or list. - ValueError: If the 'answer' in entry cannot be parsed to a number. - TypeError: If the answer is not a string or list. - ValueError: If the answer cannot be parsed to a number. """ oracle_solutions = self._parse_score_to_list( entry["answer"]) # Parse oracle solutions