mirror of
https://github.com/open-thought/reasoning-gym.git
synced 2026-05-01 17:45:24 +00:00
[poly-reward] minor updates to the docstrings
This commit is contained in:
parent
0f4ab53bd3
commit
7a6f7ea9da
1 changed files with 11 additions and 13 deletions
|
|
@ -154,15 +154,17 @@ class PolynomialEquationsDataset(ProceduralDataset):
|
|||
def _parse_score_to_list(self, answer: Optional[str]) -> List[float]:
|
||||
"""Parses a comma-separated string of scores into a sorted list of floats.
|
||||
|
||||
This method takes a string containing comma-separated numeric values, attempts to convert each value to a float,
|
||||
and returns a sorted list of these floats. Any values that cannot be converted to a float are ignored.
|
||||
This method takes a string containing comma-separated numeric values,
|
||||
attempts to convert each value to a float, and returns a sorted list of these floats.
|
||||
Any values that cannot be converted to a float are ignored.
|
||||
Handles empty strings gracefully.
|
||||
|
||||
Args:
|
||||
answer: An optional string containing comma-separated numeric values. Can be None or an empty string.
|
||||
|
||||
answer: An optional string containing comma-separated numeric values.
|
||||
Can be None or an empty string.
|
||||
Returns:
|
||||
A sorted list of floats parsed from the input string. Returns an empty list if the input is None, empty, or contains no valid numeric values.
|
||||
A sorted list of floats parsed from the input string.
|
||||
Returns an empty list if the input is None, empty, or contains no valid numeric values.
|
||||
"""
|
||||
|
||||
if answer is None or len(answer) == 0: # Handle None or empty input
|
||||
|
|
@ -175,7 +177,7 @@ class PolynomialEquationsDataset(ProceduralDataset):
|
|||
output_float_vals.append(float(output_val.strip()))
|
||||
except ValueError:
|
||||
# Ignore values that cannot be converted to float
|
||||
continue # Continue to the next value in the string
|
||||
continue
|
||||
|
||||
return sorted(output_float_vals) # Return the sorted list of floats
|
||||
|
||||
|
|
@ -185,7 +187,9 @@ class PolynomialEquationsDataset(ProceduralDataset):
|
|||
This function compares a predicted answer (or list of answers) to a set of oracle solutions
|
||||
(also a list of numbers). It calculates a reward based on how close the predicted solutions
|
||||
are to the oracle solutions, using an exponential decay function. It also applies penalties
|
||||
for missing or extra predicted solutions.
|
||||
for missing or extra predicted solutions. The implementation is a greedy algorithm where we
|
||||
find the closest matching oracle solution for a given predicted solution and only allow an
|
||||
oracle solution to match once.
|
||||
|
||||
Args:
|
||||
answer: The predicted answer (or a string that can be parsed into a list of numbers).
|
||||
|
|
@ -195,12 +199,6 @@ class PolynomialEquationsDataset(ProceduralDataset):
|
|||
|
||||
Returns:
|
||||
A float representing the final score. The score is non-negative.
|
||||
|
||||
Raises:
|
||||
TypeError: If the 'answer' in entry is not a string or list.
|
||||
ValueError: If the 'answer' in entry cannot be parsed to a number.
|
||||
TypeError: If the answer is not a string or list.
|
||||
ValueError: If the answer cannot be parsed to a number.
|
||||
"""
|
||||
oracle_solutions = self._parse_score_to_list(
|
||||
entry["answer"]) # Parse oracle solutions
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue