[poly-reward] minor updates to the docstrings

This commit is contained in:
rishabhranawat 2025-02-08 21:41:18 -08:00
parent 0f4ab53bd3
commit 7a6f7ea9da

View file

@ -154,15 +154,17 @@ class PolynomialEquationsDataset(ProceduralDataset):
def _parse_score_to_list(self, answer: Optional[str]) -> List[float]:
"""Parses a comma-separated string of scores into a sorted list of floats.
This method takes a string containing comma-separated numeric values, attempts to convert each value to a float,
and returns a sorted list of these floats. Any values that cannot be converted to a float are ignored.
This method takes a string containing comma-separated numeric values,
attempts to convert each value to a float, and returns a sorted list of these floats.
Any values that cannot be converted to a float are ignored.
Handles empty strings gracefully.
Args:
answer: An optional string containing comma-separated numeric values. Can be None or an empty string.
answer: An optional string containing comma-separated numeric values.
Can be None or an empty string.
Returns:
A sorted list of floats parsed from the input string. Returns an empty list if the input is None, empty, or contains no valid numeric values.
A sorted list of floats parsed from the input string.
Returns an empty list if the input is None, empty, or contains no valid numeric values.
"""
if answer is None or len(answer) == 0: # Handle None or empty input
@ -175,7 +177,7 @@ class PolynomialEquationsDataset(ProceduralDataset):
output_float_vals.append(float(output_val.strip()))
except ValueError:
# Ignore values that cannot be converted to float
continue # Continue to the next value in the string
continue
return sorted(output_float_vals) # Return the sorted list of floats
@ -185,7 +187,9 @@ class PolynomialEquationsDataset(ProceduralDataset):
This function compares a predicted answer (or list of answers) to a set of oracle solutions
(also a list of numbers). It calculates a reward based on how close the predicted solutions
are to the oracle solutions, using an exponential decay function. It also applies penalties
for missing or extra predicted solutions.
for missing or extra predicted solutions. The implementation is a greedy algorithm where we
find the closest matching oracle solution for a given predicted solution and only allow an
oracle solution to match once.
Args:
answer: The predicted answer (or a string that can be parsed into a list of numbers).
@ -195,12 +199,6 @@ class PolynomialEquationsDataset(ProceduralDataset):
Returns:
A float representing the final score. The score is non-negative.
Raises:
TypeError: If the 'answer' in entry is not a string or list.
ValueError: If the 'answer' in entry cannot be parsed to a number.
TypeError: If the answer is not a string or list.
ValueError: If the answer cannot be parsed to a number.
"""
oracle_solutions = self._parse_score_to_list(
entry["answer"]) # Parse oracle solutions