mirror of
https://github.com/open-thought/reasoning-gym.git
synced 2026-04-19 12:58:07 +00:00
Initial scoring algo for codeio
This commit is contained in:
parent
1795c8ea7a
commit
43daec67ea
1 changed files with 29 additions and 7 deletions
|
|
@ -1,4 +1,3 @@
|
|||
# TODO: consider whether this belongs in the "code" directory
|
||||
import gzip
|
||||
import json
|
||||
from dataclasses import dataclass
|
||||
|
|
@ -21,7 +20,7 @@ Given the following input:
|
|||
|
||||
{2}
|
||||
|
||||
Can you predict the output without writing any code? Please think and then provide only the exact output as your final answer, which should strictly match the output requirement as specified.
|
||||
Can you predict the output without writing any code? Please think and then provide the exact output in the form of a JSON object as your final answer. The keys and values of the object should strictly match the output requirement as specified.
|
||||
|
||||
Tip: Here is a reference code snippet for this question. You can refer to this code to guide your reasoning but not copy spans of code directly.
|
||||
|
||||
|
|
@ -41,7 +40,7 @@ Given the following output:
|
|||
|
||||
{2}
|
||||
|
||||
Can you predict a feasible input without writing any code? Please reason and put your final answer in the following json format: "input": <your input>, where <your input> should be a dictionary, even if the there is only one input variable, with keys strictly matching the input variables' names as specified.
|
||||
Can you predict a feasible input without writing any code? Please reason and put your final answer in the form of a JSON object, even if the there is only one input variable, with keys strictly matching the input variables' names as specified.
|
||||
|
||||
Tip: Here is a reference code snippet for this question. You can refer to this code to guide your reasoning but not copy spans of code directly.
|
||||
|
||||
|
|
@ -120,10 +119,10 @@ class CodeIODataset(ProceduralDataset):
|
|||
|
||||
if rng.random() < self.config.input_prediction_probability:
|
||||
question = OUTPUT_PREDICTION_PROMPT_TEMPLATE.format(query, parameters, input_data, reference_code)
|
||||
solution = output_data
|
||||
solution = json.dumps(output_data)
|
||||
else:
|
||||
question = INPUT_PREDICTION_PROMPT_TEMPLATE.format(query, parameters, output_data, reference_code)
|
||||
solution = input_data
|
||||
solution = json.dumps(input_data)
|
||||
|
||||
return {
|
||||
"question": question,
|
||||
|
|
@ -132,15 +131,38 @@ class CodeIODataset(ProceduralDataset):
|
|||
}
|
||||
|
||||
def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
|
||||
# TODO: better answer scoring
|
||||
# TODO: this scoring could definitely be refined
|
||||
oracle_answer = entry["answer"].strip()
|
||||
reward = 0.0
|
||||
if answer is not None and len(answer) > 0:
|
||||
answer = answer.strip()
|
||||
if answer == oracle_answer:
|
||||
reward = 1.0
|
||||
elif "{" in answer and "}" in answer:
|
||||
# Check if the answer contains a correct format JSON object somewhere
|
||||
# But penalise for length & accuracy
|
||||
ans_first_open, ans_last_close = answer.index("{"), answer.rindex("}")
|
||||
extra_chars = len(answer[:ans_first_open]) + len(answer[ans_last_close + 1 :])
|
||||
|
||||
try:
|
||||
answer_dict = json.loads(answer[ans_first_open : ans_last_close + 1])
|
||||
oracle_dict = json.loads(oracle_answer)
|
||||
if answer_dict == oracle_dict:
|
||||
# 0.5 is arbitrary here, but the answers are very short so it seems harsh to penalize too much
|
||||
# e.g. if oracle is {"steps": "3"} and answer is "The correct answer is: {"steps": "3"}"
|
||||
reward = max(len(oracle_answer) / (len(oracle_answer) + 0.5 * extra_chars), 0.2)
|
||||
elif answer_dict.keys() == oracle_dict.keys():
|
||||
# Wrong answer, but at least the right format
|
||||
reward = 0.1
|
||||
else:
|
||||
# At least we got a JSON object, I guess?
|
||||
reward = 0.05
|
||||
except json.JSONDecodeError:
|
||||
if oracle_answer in answer:
|
||||
reward = len(oracle_answer) / len(answer)
|
||||
elif oracle_answer in answer:
|
||||
reward = len(oracle_answer) / len(answer)
|
||||
# max() to avoid penalising too heavily, since correct answers are short here
|
||||
reward = max(len(oracle_answer) / len(answer), 0.2)
|
||||
else:
|
||||
reward = 0.01
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue