mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-19 12:57:58 +00:00
fix lintings
This commit is contained in:
parent
c980157a5a
commit
b9dac318ee
2 changed files with 59 additions and 35 deletions
|
|
@ -27,10 +27,22 @@ from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer
|
|||
|
||||
system_prompt = ""
|
||||
|
||||
system_prompt += f"You are an expert Python programmer. You will be given a question (problem specification) and will generate a correct Python program that matches the specification and passes all tests."
|
||||
system_prompt += (
|
||||
"You are an expert Python programmer. You will be given a question "
|
||||
"(problem specification) and will generate a correct Python program "
|
||||
"that matches the specification and passes all tests."
|
||||
)
|
||||
|
||||
FORMATTING_MESSAGE_WITH_STARTER_CODE = "You will use the following starter code to write the solution to the problem and enclose your code within delimiters."
|
||||
FORMATTING_WITHOUT_STARTER_CODE = "Read the inputs from stdin, solve the problem, and write the answer to stdout (do not directly test on the sample inputs). Enclose your code within delimiters as follows. Ensure that when the python program runs it reads the inputs runs the algorithm and writes output to STDOUT."
|
||||
FORMATTING_MESSAGE_WITH_STARTER_CODE = (
|
||||
"You will use the following starter code to write the solution to the "
|
||||
"problem and enclose your code within delimiters."
|
||||
)
|
||||
FORMATTING_WITHOUT_STARTER_CODE = (
|
||||
"Read the inputs from stdin, solve the problem, and write the answer to "
|
||||
"stdout (do not directly test on the sample inputs). Enclose your code "
|
||||
"within delimiters as follows. Ensure that when the python program runs "
|
||||
"it reads the inputs runs the algorithm and writes output to STDOUT."
|
||||
)
|
||||
|
||||
async_semaphore = asyncio.Semaphore(100)
|
||||
|
||||
|
|
@ -45,8 +57,8 @@ def get_prompt(question, problem_type, starter_code=None):
|
|||
pass
|
||||
else:
|
||||
prompt += f"{FORMATTING_WITHOUT_STARTER_CODE}\n"
|
||||
prompt += f"```python\n# YOUR CODE HERE\n```\n\n"
|
||||
prompt += f"### Answer: (use the provided format with backticks)\n\n"
|
||||
prompt += "```python\n# YOUR CODE HERE\n```\n\n"
|
||||
prompt += "### Answer: (use the provided format with backticks)\n\n"
|
||||
return prompt
|
||||
|
||||
|
||||
|
|
@ -172,7 +184,7 @@ class CodingEnv(BaseEnv):
|
|||
|
||||
chat_completion = await self.server.chat_completion(
|
||||
messages=[system_msg, user_msg],
|
||||
n=1, ### CHANGE
|
||||
n=1, # CHANGE
|
||||
max_tokens=max_tokens,
|
||||
temperature=temp,
|
||||
top_p=top_p,
|
||||
|
|
@ -526,7 +538,9 @@ class CodingEnv(BaseEnv):
|
|||
old_completion_idx = 0
|
||||
new_completion_idx = 0
|
||||
|
||||
# assert len(self.temp_metrics["num_correct"]) == len(self.completion_lengths), f"TEMP METRICS {len(self.temp_metrics['num_correct'])} and COMPLETION LENGTHS {len(self.completion_lengths)} MUST BE SAME LENGTH"
|
||||
# assert len(self.temp_metrics["num_correct"]) == len(self.completion_lengths),
|
||||
# f"TEMP METRICS {len(self.temp_metrics['num_correct'])} and
|
||||
# COMPLETION LENGTHS {len(self.completion_lengths)} MUST BE SAME LENGTH"
|
||||
print(
|
||||
"TEMP METRICS LENGTHS",
|
||||
len(self.temp_metrics["num_correct"]),
|
||||
|
|
@ -567,9 +581,10 @@ class CodingEnv(BaseEnv):
|
|||
|
||||
new_completion_idx += self.config.batch_size
|
||||
|
||||
assert old_completion_idx <= len(
|
||||
self.completion_lengths
|
||||
), f"OLD COMPLETION IDX {old_completion_idx} and COMPLETION LENGTHS {len(self.completion_lengths)} MUST BE SMALLER"
|
||||
assert old_completion_idx <= len(self.completion_lengths), (
|
||||
f"OLD COMPLETION IDX {old_completion_idx} and "
|
||||
f"COMPLETION LENGTHS {len(self.completion_lengths)} MUST BE SMALLER"
|
||||
)
|
||||
|
||||
wandb_metrics["train/completion_lengths"] = sum(
|
||||
self.completion_lengths[old_completion_idx:new_completion_idx]
|
||||
|
|
@ -653,7 +668,7 @@ class CodingEnv(BaseEnv):
|
|||
if self.config.dataset_name == "deepmind":
|
||||
self.train = load_dataset(
|
||||
"deepmind/code_contests", split="train"
|
||||
) ### CHANGE
|
||||
) # CHANGE
|
||||
else:
|
||||
self.train = load_dataset(
|
||||
"NousResearch/RLVR_Coding_Problems", split="train"
|
||||
|
|
@ -665,7 +680,7 @@ class CodingEnv(BaseEnv):
|
|||
self.test[-1]["idx"] = len(self.test) - 1
|
||||
self.test[-1]["split"] = "test"
|
||||
|
||||
self.iter = 0 ### CHANGE
|
||||
self.iter = 0 # CHANGE
|
||||
self.lock = asyncio.Lock()
|
||||
self.lock2 = asyncio.Lock()
|
||||
|
||||
|
|
@ -749,7 +764,6 @@ class CodingEnv(BaseEnv):
|
|||
scores["scores"] = list()
|
||||
scores["overrides"] = list()
|
||||
|
||||
results = []
|
||||
codes = []
|
||||
lengths = []
|
||||
errors = []
|
||||
|
|
@ -780,7 +794,7 @@ class CodingEnv(BaseEnv):
|
|||
|
||||
codes.append(code)
|
||||
|
||||
if code == None:
|
||||
if code is None:
|
||||
scores["scores"].append(-1.0)
|
||||
errors.append({"error_message": "No code"})
|
||||
continue
|
||||
|
|
@ -788,7 +802,7 @@ class CodingEnv(BaseEnv):
|
|||
try:
|
||||
async with async_semaphore:
|
||||
res, metadata = await run_test.remote.aio(test_cases, code)
|
||||
except modal.exception.RemoteError as e:
|
||||
except modal.exception.RemoteError:
|
||||
res = [False]
|
||||
rprint("index", cur_id)
|
||||
rprint(test_cases["tests"]["fn_name"])
|
||||
|
|
|
|||
|
|
@ -21,9 +21,19 @@ from types import ModuleType
|
|||
from unittest.mock import mock_open, patch
|
||||
|
||||
import modal
|
||||
import numpy as np
|
||||
|
||||
import_string = "from string import *\nfrom re import *\nfrom datetime import *\nfrom collections import *\nfrom heapq import *\nfrom bisect import *\nfrom copy import *\nfrom math import *\nfrom random import *\nfrom statistics import *\nfrom itertools import *\nfrom functools import *\nfrom operator import *\nfrom io import *\nfrom sys import *\nfrom json import *\nfrom builtins import *\nfrom typing import *\nimport string\nimport re\nimport datetime\nimport collections\nimport heapq\nimport bisect\nimport copy\nimport math\nimport random\nimport statistics\nimport itertools\nimport functools\nimport operator\nimport io\nimport sys\nimport json\nsys.setrecursionlimit(50000)\n"
|
||||
import_string = (
|
||||
"from string import *\nfrom re import *\nfrom datetime import *\n"
|
||||
"from collections import *\nfrom heapq import *\nfrom bisect import *\n"
|
||||
"from copy import *\nfrom math import *\nfrom random import *\n"
|
||||
"from statistics import *\nfrom itertools import *\nfrom functools import *\n"
|
||||
"from operator import *\nfrom io import *\nfrom sys import *\n"
|
||||
"from json import *\nfrom builtins import *\nfrom typing import *\n"
|
||||
"import string\nimport re\nimport datetime\nimport collections\n"
|
||||
"import heapq\nimport bisect\nimport copy\nimport math\nimport random\n"
|
||||
"import statistics\nimport itertools\nimport functools\nimport operator\n"
|
||||
"import io\nimport sys\nimport json\nsys.setrecursionlimit(50000)\n"
|
||||
)
|
||||
|
||||
image = modal.Image.debian_slim(python_version="3.10").pip_install("numpy")
|
||||
app = modal.App("joeli-lcb", image=image)
|
||||
|
|
@ -121,7 +131,7 @@ def clean_if_name(code: str) -> str:
|
|||
code = (
|
||||
ast.unparse(astree.body[:-1]) + "\n" + ast.unparse(last_block.body) # type: ignore
|
||||
)
|
||||
except:
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return code
|
||||
|
|
@ -155,7 +165,7 @@ def make_function(code: str) -> str:
|
|||
+ ast.unparse(function_ast) # type: ignore
|
||||
)
|
||||
return main_code
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
return code
|
||||
|
||||
|
||||
|
|
@ -187,7 +197,7 @@ def call_method(method, inputs):
|
|||
_method.__globals__["stdout"] = sys.stdout
|
||||
try:
|
||||
return _method()
|
||||
except SystemExit as e:
|
||||
except SystemExit:
|
||||
pass
|
||||
finally:
|
||||
pass
|
||||
|
|
@ -199,7 +209,7 @@ def get_function(compiled_sol, fn_name: str): # type: ignore
|
|||
try:
|
||||
assert hasattr(compiled_sol, fn_name)
|
||||
return getattr(compiled_sol, fn_name)
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
return
|
||||
|
||||
|
||||
|
|
@ -230,13 +240,13 @@ def compile_code(code: str, timeout: int):
|
|||
def convert_line_to_decimals(line: str) -> tuple[bool, list[Decimal]]:
|
||||
try:
|
||||
decimal_line = [Decimal(elem) for elem in line.split()]
|
||||
except:
|
||||
except Exception:
|
||||
return False, []
|
||||
return True, decimal_line
|
||||
|
||||
|
||||
def get_stripped_lines(val: str):
|
||||
## you don't want empty lines to add empty list after splitlines!
|
||||
# you don't want empty lines to add empty list after splitlines!
|
||||
val = val.strip()
|
||||
|
||||
return [val_line.strip() for val_line in val.split("\n")]
|
||||
|
|
@ -349,10 +359,10 @@ def grade_stdio(
|
|||
timeout: int,
|
||||
):
|
||||
print("ORIGINAL CODE\n", code)
|
||||
## runtime doesn't interact well with __name__ == '__main__'
|
||||
# runtime doesn't interact well with __name__ == '__main__'
|
||||
code = clean_if_name(code)
|
||||
|
||||
## we wrap the given code inside another function
|
||||
# we wrap the given code inside another function
|
||||
code = make_function(code)
|
||||
|
||||
compiled_sol = compile_code(code, timeout)
|
||||
|
|
@ -415,8 +425,8 @@ def grade_stdio(
|
|||
else:
|
||||
stripped_gt_out_lines = gt_out
|
||||
|
||||
## WA happens in multiple circumstances
|
||||
## so cache the return to make it clean!
|
||||
# WA happens in multiple circumstances
|
||||
# so cache the return to make it clean!
|
||||
WA_send_args = {
|
||||
"output": truncatefn(prediction),
|
||||
"inputs": truncatefn(gt_inp),
|
||||
|
|
@ -434,18 +444,19 @@ def grade_stdio(
|
|||
stripped_gt_out_line,
|
||||
) in enumerate(zip(stripped_prediction_lines, stripped_gt_out_lines)):
|
||||
WA_send_args["error_message"] = (
|
||||
f"Wrong answer at {output_line_idx=}: {truncatefn(stripped_prediction_line)} != {truncatefn(stripped_gt_out_line)}"
|
||||
f"Wrong answer at {output_line_idx=}: "
|
||||
f"{truncatefn(stripped_prediction_line)} != {truncatefn(stripped_gt_out_line)}"
|
||||
)
|
||||
|
||||
## CASE 1: exact match
|
||||
# CASE 1: exact match
|
||||
if stripped_prediction_line == stripped_gt_out_line:
|
||||
continue
|
||||
|
||||
## CASE 2: element-wise comparision
|
||||
## if there are floating elements
|
||||
## use `decimal` library for good floating point comparision
|
||||
## otherwise gotcha: np.isclose(50000000000000000, 50000000000000001) = True
|
||||
## note that we should always be able to convert to decimals
|
||||
# CASE 2: element-wise comparision
|
||||
# if there are floating elements
|
||||
# use `decimal` library for good floating point comparision
|
||||
# otherwise gotcha: np.isclose(50000000000000000, 50000000000000001) = True
|
||||
# note that we should always be able to convert to decimals
|
||||
|
||||
success, decimal_prediction_line = convert_line_to_decimals(
|
||||
stripped_prediction_line
|
||||
|
|
@ -524,7 +535,6 @@ def run_test(sample, test=None, debug=False, timeout=15):
|
|||
return in_outs, {"error": "no test code provided"}
|
||||
elif test is not None:
|
||||
results = []
|
||||
sol = import_string
|
||||
if debug:
|
||||
print(f"loading test code = {datetime.now().time()}")
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue