fix lintings

This commit is contained in:
teknium 2025-12-31 21:04:40 +00:00
parent c980157a5a
commit b9dac318ee
2 changed files with 59 additions and 35 deletions

View file

@ -27,10 +27,22 @@ from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer
system_prompt = ""
system_prompt += f"You are an expert Python programmer. You will be given a question (problem specification) and will generate a correct Python program that matches the specification and passes all tests."
system_prompt += (
"You are an expert Python programmer. You will be given a question "
"(problem specification) and will generate a correct Python program "
"that matches the specification and passes all tests."
)
FORMATTING_MESSAGE_WITH_STARTER_CODE = "You will use the following starter code to write the solution to the problem and enclose your code within delimiters."
FORMATTING_WITHOUT_STARTER_CODE = "Read the inputs from stdin, solve the problem, and write the answer to stdout (do not directly test on the sample inputs). Enclose your code within delimiters as follows. Ensure that when the python program runs it reads the inputs runs the algorithm and writes output to STDOUT."
FORMATTING_MESSAGE_WITH_STARTER_CODE = (
"You will use the following starter code to write the solution to the "
"problem and enclose your code within delimiters."
)
FORMATTING_WITHOUT_STARTER_CODE = (
"Read the inputs from stdin, solve the problem, and write the answer to "
"stdout (do not directly test on the sample inputs). Enclose your code "
"within delimiters as follows. Ensure that when the python program runs "
"it reads the inputs runs the algorithm and writes output to STDOUT."
)
async_semaphore = asyncio.Semaphore(100)
@ -45,8 +57,8 @@ def get_prompt(question, problem_type, starter_code=None):
pass
else:
prompt += f"{FORMATTING_WITHOUT_STARTER_CODE}\n"
prompt += f"```python\n# YOUR CODE HERE\n```\n\n"
prompt += f"### Answer: (use the provided format with backticks)\n\n"
prompt += "```python\n# YOUR CODE HERE\n```\n\n"
prompt += "### Answer: (use the provided format with backticks)\n\n"
return prompt
@ -172,7 +184,7 @@ class CodingEnv(BaseEnv):
chat_completion = await self.server.chat_completion(
messages=[system_msg, user_msg],
n=1, ### CHANGE
n=1, # CHANGE
max_tokens=max_tokens,
temperature=temp,
top_p=top_p,
@ -526,7 +538,9 @@ class CodingEnv(BaseEnv):
old_completion_idx = 0
new_completion_idx = 0
# assert len(self.temp_metrics["num_correct"]) == len(self.completion_lengths), f"TEMP METRICS {len(self.temp_metrics['num_correct'])} and COMPLETION LENGTHS {len(self.completion_lengths)} MUST BE SAME LENGTH"
# assert len(self.temp_metrics["num_correct"]) == len(self.completion_lengths),
# f"TEMP METRICS {len(self.temp_metrics['num_correct'])} and
# COMPLETION LENGTHS {len(self.completion_lengths)} MUST BE SAME LENGTH"
print(
"TEMP METRICS LENGTHS",
len(self.temp_metrics["num_correct"]),
@ -567,9 +581,10 @@ class CodingEnv(BaseEnv):
new_completion_idx += self.config.batch_size
assert old_completion_idx <= len(
self.completion_lengths
), f"OLD COMPLETION IDX {old_completion_idx} and COMPLETION LENGTHS {len(self.completion_lengths)} MUST BE SMALLER"
assert old_completion_idx <= len(self.completion_lengths), (
f"OLD COMPLETION IDX {old_completion_idx} and "
f"COMPLETION LENGTHS {len(self.completion_lengths)} MUST BE SMALLER"
)
wandb_metrics["train/completion_lengths"] = sum(
self.completion_lengths[old_completion_idx:new_completion_idx]
@ -653,7 +668,7 @@ class CodingEnv(BaseEnv):
if self.config.dataset_name == "deepmind":
self.train = load_dataset(
"deepmind/code_contests", split="train"
) ### CHANGE
) # CHANGE
else:
self.train = load_dataset(
"NousResearch/RLVR_Coding_Problems", split="train"
@ -665,7 +680,7 @@ class CodingEnv(BaseEnv):
self.test[-1]["idx"] = len(self.test) - 1
self.test[-1]["split"] = "test"
self.iter = 0 ### CHANGE
self.iter = 0 # CHANGE
self.lock = asyncio.Lock()
self.lock2 = asyncio.Lock()
@ -749,7 +764,6 @@ class CodingEnv(BaseEnv):
scores["scores"] = list()
scores["overrides"] = list()
results = []
codes = []
lengths = []
errors = []
@ -780,7 +794,7 @@ class CodingEnv(BaseEnv):
codes.append(code)
if code == None:
if code is None:
scores["scores"].append(-1.0)
errors.append({"error_message": "No code"})
continue
@ -788,7 +802,7 @@ class CodingEnv(BaseEnv):
try:
async with async_semaphore:
res, metadata = await run_test.remote.aio(test_cases, code)
except modal.exception.RemoteError as e:
except modal.exception.RemoteError:
res = [False]
rprint("index", cur_id)
rprint(test_cases["tests"]["fn_name"])