mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-19 12:57:58 +00:00
fix lintings
This commit is contained in:
parent
c980157a5a
commit
b9dac318ee
2 changed files with 59 additions and 35 deletions
|
|
@ -27,10 +27,22 @@ from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer
|
|||
|
||||
system_prompt = ""
|
||||
|
||||
system_prompt += f"You are an expert Python programmer. You will be given a question (problem specification) and will generate a correct Python program that matches the specification and passes all tests."
|
||||
system_prompt += (
|
||||
"You are an expert Python programmer. You will be given a question "
|
||||
"(problem specification) and will generate a correct Python program "
|
||||
"that matches the specification and passes all tests."
|
||||
)
|
||||
|
||||
FORMATTING_MESSAGE_WITH_STARTER_CODE = "You will use the following starter code to write the solution to the problem and enclose your code within delimiters."
|
||||
FORMATTING_WITHOUT_STARTER_CODE = "Read the inputs from stdin, solve the problem, and write the answer to stdout (do not directly test on the sample inputs). Enclose your code within delimiters as follows. Ensure that when the python program runs it reads the inputs runs the algorithm and writes output to STDOUT."
|
||||
FORMATTING_MESSAGE_WITH_STARTER_CODE = (
|
||||
"You will use the following starter code to write the solution to the "
|
||||
"problem and enclose your code within delimiters."
|
||||
)
|
||||
FORMATTING_WITHOUT_STARTER_CODE = (
|
||||
"Read the inputs from stdin, solve the problem, and write the answer to "
|
||||
"stdout (do not directly test on the sample inputs). Enclose your code "
|
||||
"within delimiters as follows. Ensure that when the python program runs "
|
||||
"it reads the inputs runs the algorithm and writes output to STDOUT."
|
||||
)
|
||||
|
||||
async_semaphore = asyncio.Semaphore(100)
|
||||
|
||||
|
|
@ -45,8 +57,8 @@ def get_prompt(question, problem_type, starter_code=None):
|
|||
pass
|
||||
else:
|
||||
prompt += f"{FORMATTING_WITHOUT_STARTER_CODE}\n"
|
||||
prompt += f"```python\n# YOUR CODE HERE\n```\n\n"
|
||||
prompt += f"### Answer: (use the provided format with backticks)\n\n"
|
||||
prompt += "```python\n# YOUR CODE HERE\n```\n\n"
|
||||
prompt += "### Answer: (use the provided format with backticks)\n\n"
|
||||
return prompt
|
||||
|
||||
|
||||
|
|
@ -172,7 +184,7 @@ class CodingEnv(BaseEnv):
|
|||
|
||||
chat_completion = await self.server.chat_completion(
|
||||
messages=[system_msg, user_msg],
|
||||
n=1, ### CHANGE
|
||||
n=1, # CHANGE
|
||||
max_tokens=max_tokens,
|
||||
temperature=temp,
|
||||
top_p=top_p,
|
||||
|
|
@ -526,7 +538,9 @@ class CodingEnv(BaseEnv):
|
|||
old_completion_idx = 0
|
||||
new_completion_idx = 0
|
||||
|
||||
# assert len(self.temp_metrics["num_correct"]) == len(self.completion_lengths), f"TEMP METRICS {len(self.temp_metrics['num_correct'])} and COMPLETION LENGTHS {len(self.completion_lengths)} MUST BE SAME LENGTH"
|
||||
# assert len(self.temp_metrics["num_correct"]) == len(self.completion_lengths),
|
||||
# f"TEMP METRICS {len(self.temp_metrics['num_correct'])} and
|
||||
# COMPLETION LENGTHS {len(self.completion_lengths)} MUST BE SAME LENGTH"
|
||||
print(
|
||||
"TEMP METRICS LENGTHS",
|
||||
len(self.temp_metrics["num_correct"]),
|
||||
|
|
@ -567,9 +581,10 @@ class CodingEnv(BaseEnv):
|
|||
|
||||
new_completion_idx += self.config.batch_size
|
||||
|
||||
assert old_completion_idx <= len(
|
||||
self.completion_lengths
|
||||
), f"OLD COMPLETION IDX {old_completion_idx} and COMPLETION LENGTHS {len(self.completion_lengths)} MUST BE SMALLER"
|
||||
assert old_completion_idx <= len(self.completion_lengths), (
|
||||
f"OLD COMPLETION IDX {old_completion_idx} and "
|
||||
f"COMPLETION LENGTHS {len(self.completion_lengths)} MUST BE SMALLER"
|
||||
)
|
||||
|
||||
wandb_metrics["train/completion_lengths"] = sum(
|
||||
self.completion_lengths[old_completion_idx:new_completion_idx]
|
||||
|
|
@ -653,7 +668,7 @@ class CodingEnv(BaseEnv):
|
|||
if self.config.dataset_name == "deepmind":
|
||||
self.train = load_dataset(
|
||||
"deepmind/code_contests", split="train"
|
||||
) ### CHANGE
|
||||
) # CHANGE
|
||||
else:
|
||||
self.train = load_dataset(
|
||||
"NousResearch/RLVR_Coding_Problems", split="train"
|
||||
|
|
@ -665,7 +680,7 @@ class CodingEnv(BaseEnv):
|
|||
self.test[-1]["idx"] = len(self.test) - 1
|
||||
self.test[-1]["split"] = "test"
|
||||
|
||||
self.iter = 0 ### CHANGE
|
||||
self.iter = 0 # CHANGE
|
||||
self.lock = asyncio.Lock()
|
||||
self.lock2 = asyncio.Lock()
|
||||
|
||||
|
|
@ -749,7 +764,6 @@ class CodingEnv(BaseEnv):
|
|||
scores["scores"] = list()
|
||||
scores["overrides"] = list()
|
||||
|
||||
results = []
|
||||
codes = []
|
||||
lengths = []
|
||||
errors = []
|
||||
|
|
@ -780,7 +794,7 @@ class CodingEnv(BaseEnv):
|
|||
|
||||
codes.append(code)
|
||||
|
||||
if code == None:
|
||||
if code is None:
|
||||
scores["scores"].append(-1.0)
|
||||
errors.append({"error_message": "No code"})
|
||||
continue
|
||||
|
|
@ -788,7 +802,7 @@ class CodingEnv(BaseEnv):
|
|||
try:
|
||||
async with async_semaphore:
|
||||
res, metadata = await run_test.remote.aio(test_cases, code)
|
||||
except modal.exception.RemoteError as e:
|
||||
except modal.exception.RemoteError:
|
||||
res = [False]
|
||||
rprint("index", cur_id)
|
||||
rprint(test_cases["tests"]["fn_name"])
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue