Add input prediction

This commit is contained in:
Oliver 2025-02-23 20:27:27 +00:00
parent 489dea7267
commit 40d7dfdb5f

View file

@ -27,6 +27,26 @@ Tip: Here is a reference code snippet for this question. You can refer to this c
{3} {3}
""" """
INPUT_PREDICTION_PROMPT_TEMPLATE = """
You are given a question that requires some input and output variables as follows:
{0}
The input and output requirements are as follows:
{1}
Given the following output:
{2}
Can you predict a feasible input without writing any code? Please reason and put your final answer in the following json format: "input": <your input>, where <your input> should be a dictionary, even if the there is only one input variable, with keys strictly matching the input variables' names as specified.
Tip: Here is a reference code snippet for this question. You can refer to this code to guide your reasoning but not copy spans of code directly.
{3}
"""
# TODO: also add input prediction prompt # TODO: also add input prediction prompt
@ -36,6 +56,7 @@ class CodeIOConfig:
seed: Optional[int] = None seed: Optional[int] = None
size: int = 500 size: int = 500
input_prediction_probability: float = 0.5
def validate(self) -> None: def validate(self) -> None:
"""Validate configuration parameters""" """Validate configuration parameters"""
@ -95,11 +116,12 @@ class CodeIODataset(ProceduralDataset):
input_data, output_data = self._generate_io_pairs(reference_code, input_generator_code, num_pairs=1)[0] input_data, output_data = self._generate_io_pairs(reference_code, input_generator_code, num_pairs=1)[0]
# TODO add chance of input prediction rather than output if rng.random() < self.config.input_prediction_probability:
question = OUTPUT_PREDICTION_PROMPT_TEMPLATE.format(query, parameters, input_data, reference_code)
question = OUTPUT_PREDICTION_PROMPT_TEMPLATE.format(query, parameters, input_data, reference_code) solution = output_data
# TODO: consider changing format here else:
solution = output_data question = INPUT_PREDICTION_PROMPT_TEMPLATE.format(query, parameters, output_data, reference_code)
solution = input_data
return { return {
"question": question, "question": question,