fix lintings

This commit is contained in:
teknium 2025-12-31 21:04:40 +00:00
parent c980157a5a
commit b9dac318ee
2 changed files with 59 additions and 35 deletions

View file

@@ -27,10 +27,22 @@ from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer
system_prompt = ""
system_prompt += f"You are an expert Python programmer. You will be given a question (problem specification) and will generate a correct Python program that matches the specification and passes all tests."
system_prompt += (
"You are an expert Python programmer. You will be given a question "
"(problem specification) and will generate a correct Python program "
"that matches the specification and passes all tests."
)
FORMATTING_MESSAGE_WITH_STARTER_CODE = "You will use the following starter code to write the solution to the problem and enclose your code within delimiters."
FORMATTING_WITHOUT_STARTER_CODE = "Read the inputs from stdin, solve the problem, and write the answer to stdout (do not directly test on the sample inputs). Enclose your code within delimiters as follows. Ensure that when the python program runs it reads the inputs runs the algorithm and writes output to STDOUT."
FORMATTING_MESSAGE_WITH_STARTER_CODE = (
"You will use the following starter code to write the solution to the "
"problem and enclose your code within delimiters."
)
FORMATTING_WITHOUT_STARTER_CODE = (
"Read the inputs from stdin, solve the problem, and write the answer to "
"stdout (do not directly test on the sample inputs). Enclose your code "
"within delimiters as follows. Ensure that when the python program runs "
"it reads the inputs runs the algorithm and writes output to STDOUT."
)
async_semaphore = asyncio.Semaphore(100)
@@ -45,8 +57,8 @@ def get_prompt(question, problem_type, starter_code=None):
pass
else:
prompt += f"{FORMATTING_WITHOUT_STARTER_CODE}\n"
prompt += f"```python\n# YOUR CODE HERE\n```\n\n"
prompt += f"### Answer: (use the provided format with backticks)\n\n"
prompt += "```python\n# YOUR CODE HERE\n```\n\n"
prompt += "### Answer: (use the provided format with backticks)\n\n"
return prompt
@@ -172,7 +184,7 @@ class CodingEnv(BaseEnv):
chat_completion = await self.server.chat_completion(
messages=[system_msg, user_msg],
n=1, ### CHANGE
n=1, # CHANGE
max_tokens=max_tokens,
temperature=temp,
top_p=top_p,
@@ -526,7 +538,9 @@ class CodingEnv(BaseEnv):
old_completion_idx = 0
new_completion_idx = 0
# assert len(self.temp_metrics["num_correct"]) == len(self.completion_lengths), f"TEMP METRICS {len(self.temp_metrics['num_correct'])} and COMPLETION LENGTHS {len(self.completion_lengths)} MUST BE SAME LENGTH"
# assert len(self.temp_metrics["num_correct"]) == len(self.completion_lengths),
# f"TEMP METRICS {len(self.temp_metrics['num_correct'])} and
# COMPLETION LENGTHS {len(self.completion_lengths)} MUST BE SAME LENGTH"
print(
"TEMP METRICS LENGTHS",
len(self.temp_metrics["num_correct"]),
@@ -567,9 +581,10 @@ class CodingEnv(BaseEnv):
new_completion_idx += self.config.batch_size
assert old_completion_idx <= len(
self.completion_lengths
), f"OLD COMPLETION IDX {old_completion_idx} and COMPLETION LENGTHS {len(self.completion_lengths)} MUST BE SMALLER"
assert old_completion_idx <= len(self.completion_lengths), (
f"OLD COMPLETION IDX {old_completion_idx} and "
f"COMPLETION LENGTHS {len(self.completion_lengths)} MUST BE SMALLER"
)
wandb_metrics["train/completion_lengths"] = sum(
self.completion_lengths[old_completion_idx:new_completion_idx]
@@ -653,7 +668,7 @@ class CodingEnv(BaseEnv):
if self.config.dataset_name == "deepmind":
self.train = load_dataset(
"deepmind/code_contests", split="train"
) ### CHANGE
) # CHANGE
else:
self.train = load_dataset(
"NousResearch/RLVR_Coding_Problems", split="train"
@@ -665,7 +680,7 @@ class CodingEnv(BaseEnv):
self.test[-1]["idx"] = len(self.test) - 1
self.test[-1]["split"] = "test"
self.iter = 0 ### CHANGE
self.iter = 0 # CHANGE
self.lock = asyncio.Lock()
self.lock2 = asyncio.Lock()
@@ -749,7 +764,6 @@ class CodingEnv(BaseEnv):
scores["scores"] = list()
scores["overrides"] = list()
results = []
codes = []
lengths = []
errors = []
@@ -780,7 +794,7 @@ class CodingEnv(BaseEnv):
codes.append(code)
if code == None:
if code is None:
scores["scores"].append(-1.0)
errors.append({"error_message": "No code"})
continue
@@ -788,7 +802,7 @@ class CodingEnv(BaseEnv):
try:
async with async_semaphore:
res, metadata = await run_test.remote.aio(test_cases, code)
except modal.exception.RemoteError as e:
except modal.exception.RemoteError:
res = [False]
rprint("index", cur_id)
rprint(test_cases["tests"]["fn_name"])

View file

@@ -21,9 +21,19 @@ from types import ModuleType
from unittest.mock import mock_open, patch
import modal
import numpy as np
import_string = "from string import *\nfrom re import *\nfrom datetime import *\nfrom collections import *\nfrom heapq import *\nfrom bisect import *\nfrom copy import *\nfrom math import *\nfrom random import *\nfrom statistics import *\nfrom itertools import *\nfrom functools import *\nfrom operator import *\nfrom io import *\nfrom sys import *\nfrom json import *\nfrom builtins import *\nfrom typing import *\nimport string\nimport re\nimport datetime\nimport collections\nimport heapq\nimport bisect\nimport copy\nimport math\nimport random\nimport statistics\nimport itertools\nimport functools\nimport operator\nimport io\nimport sys\nimport json\nsys.setrecursionlimit(50000)\n"
import_string = (
"from string import *\nfrom re import *\nfrom datetime import *\n"
"from collections import *\nfrom heapq import *\nfrom bisect import *\n"
"from copy import *\nfrom math import *\nfrom random import *\n"
"from statistics import *\nfrom itertools import *\nfrom functools import *\n"
"from operator import *\nfrom io import *\nfrom sys import *\n"
"from json import *\nfrom builtins import *\nfrom typing import *\n"
"import string\nimport re\nimport datetime\nimport collections\n"
"import heapq\nimport bisect\nimport copy\nimport math\nimport random\n"
"import statistics\nimport itertools\nimport functools\nimport operator\n"
"import io\nimport sys\nimport json\nsys.setrecursionlimit(50000)\n"
)
image = modal.Image.debian_slim(python_version="3.10").pip_install("numpy")
app = modal.App("joeli-lcb", image=image)
@@ -121,7 +131,7 @@ def clean_if_name(code: str) -> str:
code = (
ast.unparse(astree.body[:-1]) + "\n" + ast.unparse(last_block.body) # type: ignore
)
except:
except Exception:
pass
return code
@@ -155,7 +165,7 @@ def make_function(code: str) -> str:
+ ast.unparse(function_ast) # type: ignore
)
return main_code
except Exception as e:
except Exception:
return code
@@ -187,7 +197,7 @@ def call_method(method, inputs):
_method.__globals__["stdout"] = sys.stdout
try:
return _method()
except SystemExit as e:
except SystemExit:
pass
finally:
pass
@@ -199,7 +209,7 @@ def get_function(compiled_sol, fn_name: str): # type: ignore
try:
assert hasattr(compiled_sol, fn_name)
return getattr(compiled_sol, fn_name)
except Exception as e:
except Exception:
return
@@ -230,13 +240,13 @@ def compile_code(code: str, timeout: int):
def convert_line_to_decimals(line: str) -> tuple[bool, list[Decimal]]:
try:
decimal_line = [Decimal(elem) for elem in line.split()]
except:
except Exception:
return False, []
return True, decimal_line
def get_stripped_lines(val: str):
## you don't want empty lines to add empty list after splitlines!
# you don't want empty lines to add empty list after splitlines!
val = val.strip()
return [val_line.strip() for val_line in val.split("\n")]
@@ -349,10 +359,10 @@ def grade_stdio(
timeout: int,
):
print("ORIGINAL CODE\n", code)
## runtime doesn't interact well with __name__ == '__main__'
# runtime doesn't interact well with __name__ == '__main__'
code = clean_if_name(code)
## we wrap the given code inside another function
# we wrap the given code inside another function
code = make_function(code)
compiled_sol = compile_code(code, timeout)
@@ -415,8 +425,8 @@ def grade_stdio(
else:
stripped_gt_out_lines = gt_out
## WA happens in multiple circumstances
## so cache the return to make it clean!
# WA happens in multiple circumstances
# so cache the return to make it clean!
WA_send_args = {
"output": truncatefn(prediction),
"inputs": truncatefn(gt_inp),
@@ -434,18 +444,19 @@ def grade_stdio(
stripped_gt_out_line,
) in enumerate(zip(stripped_prediction_lines, stripped_gt_out_lines)):
WA_send_args["error_message"] = (
f"Wrong answer at {output_line_idx=}: {truncatefn(stripped_prediction_line)} != {truncatefn(stripped_gt_out_line)}"
f"Wrong answer at {output_line_idx=}: "
f"{truncatefn(stripped_prediction_line)} != {truncatefn(stripped_gt_out_line)}"
)
## CASE 1: exact match
# CASE 1: exact match
if stripped_prediction_line == stripped_gt_out_line:
continue
## CASE 2: element-wise comparision
## if there are floating elements
## use `decimal` library for good floating point comparision
## otherwise gotcha: np.isclose(50000000000000000, 50000000000000001) = True
## note that we should always be able to convert to decimals
# CASE 2: element-wise comparision
# if there are floating elements
# use `decimal` library for good floating point comparision
# otherwise gotcha: np.isclose(50000000000000000, 50000000000000001) = True
# note that we should always be able to convert to decimals
success, decimal_prediction_line = convert_line_to_decimals(
stripped_prediction_line
@@ -524,7 +535,6 @@ def run_test(sample, test=None, debug=False, timeout=15):
return in_outs, {"error": "no test code provided"}
elif test is not None:
results = []
sol = import_string
if debug:
print(f"loading test code = {datetime.now().time()}")