diff --git a/environments/code_execution_server/coding_server.py b/environments/code_execution_server/coding_server.py index 702b10cd..43581a07 100644 --- a/environments/code_execution_server/coding_server.py +++ b/environments/code_execution_server/coding_server.py @@ -27,10 +27,22 @@ from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer system_prompt = "" -system_prompt += f"You are an expert Python programmer. You will be given a question (problem specification) and will generate a correct Python program that matches the specification and passes all tests." +system_prompt += ( + "You are an expert Python programmer. You will be given a question " + "(problem specification) and will generate a correct Python program " + "that matches the specification and passes all tests." +) -FORMATTING_MESSAGE_WITH_STARTER_CODE = "You will use the following starter code to write the solution to the problem and enclose your code within delimiters." -FORMATTING_WITHOUT_STARTER_CODE = "Read the inputs from stdin, solve the problem, and write the answer to stdout (do not directly test on the sample inputs). Enclose your code within delimiters as follows. Ensure that when the python program runs it reads the inputs runs the algorithm and writes output to STDOUT." +FORMATTING_MESSAGE_WITH_STARTER_CODE = ( + "You will use the following starter code to write the solution to the " + "problem and enclose your code within delimiters." +) +FORMATTING_WITHOUT_STARTER_CODE = ( + "Read the inputs from stdin, solve the problem, and write the answer to " + "stdout (do not directly test on the sample inputs). Enclose your code " + "within delimiters as follows. Ensure that when the python program runs " + "it reads the inputs runs the algorithm and writes output to STDOUT." +) async_semaphore = asyncio.Semaphore(100) @@ -45,8 +57,8 @@ def get_prompt(question, problem_type, starter_code=None): pass else: prompt += f"{FORMATTING_WITHOUT_STARTER_CODE}\n" - prompt += f"```python\n# YOUR CODE HERE\n```\n\n" - prompt += f"### Answer: (use the provided format with backticks)\n\n" + prompt += "```python\n# YOUR CODE HERE\n```\n\n" + prompt += "### Answer: (use the provided format with backticks)\n\n" return prompt @@ -172,7 +184,7 @@ class CodingEnv(BaseEnv): chat_completion = await self.server.chat_completion( messages=[system_msg, user_msg], - n=1, ### CHANGE + n=1, # CHANGE max_tokens=max_tokens, temperature=temp, top_p=top_p, @@ -526,7 +538,9 @@ class CodingEnv(BaseEnv): old_completion_idx = 0 new_completion_idx = 0 - # assert len(self.temp_metrics["num_correct"]) == len(self.completion_lengths), f"TEMP METRICS {len(self.temp_metrics['num_correct'])} and COMPLETION LENGTHS {len(self.completion_lengths)} MUST BE SAME LENGTH" + # assert len(self.temp_metrics["num_correct"]) == len(self.completion_lengths), + # f"TEMP METRICS {len(self.temp_metrics['num_correct'])} and + # COMPLETION LENGTHS {len(self.completion_lengths)} MUST BE SAME LENGTH" print( "TEMP METRICS LENGTHS", len(self.temp_metrics["num_correct"]), @@ -567,9 +581,10 @@ class CodingEnv(BaseEnv): new_completion_idx += self.config.batch_size - assert old_completion_idx <= len( - self.completion_lengths - ), f"OLD COMPLETION IDX {old_completion_idx} and COMPLETION LENGTHS {len(self.completion_lengths)} MUST BE SMALLER" + assert old_completion_idx <= len(self.completion_lengths), ( + f"OLD COMPLETION IDX {old_completion_idx} and " + f"COMPLETION LENGTHS {len(self.completion_lengths)} MUST BE SMALLER" + ) wandb_metrics["train/completion_lengths"] = sum( self.completion_lengths[old_completion_idx:new_completion_idx] @@ -653,7 +668,7 @@ class CodingEnv(BaseEnv): if self.config.dataset_name == "deepmind": self.train = load_dataset( "deepmind/code_contests", split="train" - ) ### CHANGE + ) # CHANGE else: self.train = load_dataset( "NousResearch/RLVR_Coding_Problems", split="train" @@ -665,7 +680,7 @@ class CodingEnv(BaseEnv): self.test[-1]["idx"] = len(self.test) - 1 self.test[-1]["split"] = "test" - self.iter = 0 ### CHANGE + self.iter = 0 # CHANGE self.lock = asyncio.Lock() self.lock2 = asyncio.Lock() @@ -749,7 +764,6 @@ class CodingEnv(BaseEnv): scores["scores"] = list() scores["overrides"] = list() - results = [] codes = [] lengths = [] errors = [] @@ -780,7 +794,7 @@ class CodingEnv(BaseEnv): codes.append(code) - if code == None: + if code is None: scores["scores"].append(-1.0) errors.append({"error_message": "No code"}) continue @@ -788,7 +802,7 @@ class CodingEnv(BaseEnv): try: async with async_semaphore: res, metadata = await run_test.remote.aio(test_cases, code) - except modal.exception.RemoteError as e: + except modal.exception.RemoteError: res = [False] rprint("index", cur_id) rprint(test_cases["tests"]["fn_name"]) diff --git a/environments/code_execution_server/lcb_modal_endpoint.py b/environments/code_execution_server/lcb_modal_endpoint.py index 49b111c0..a8809724 100644 --- a/environments/code_execution_server/lcb_modal_endpoint.py +++ b/environments/code_execution_server/lcb_modal_endpoint.py @@ -21,9 +21,19 @@ from types import ModuleType from unittest.mock import mock_open, patch import modal -import numpy as np -import_string = "from string import *\nfrom re import *\nfrom datetime import *\nfrom collections import *\nfrom heapq import *\nfrom bisect import *\nfrom copy import *\nfrom math import *\nfrom random import *\nfrom statistics import *\nfrom itertools import *\nfrom functools import *\nfrom operator import *\nfrom io import *\nfrom sys import *\nfrom json import *\nfrom builtins import *\nfrom typing import *\nimport string\nimport re\nimport datetime\nimport collections\nimport heapq\nimport bisect\nimport copy\nimport math\nimport random\nimport statistics\nimport itertools\nimport functools\nimport operator\nimport io\nimport sys\nimport json\nsys.setrecursionlimit(50000)\n" +import_string = ( + "from string import *\nfrom re import *\nfrom datetime import *\n" + "from collections import *\nfrom heapq import *\nfrom bisect import *\n" + "from copy import *\nfrom math import *\nfrom random import *\n" + "from statistics import *\nfrom itertools import *\nfrom functools import *\n" + "from operator import *\nfrom io import *\nfrom sys import *\n" + "from json import *\nfrom builtins import *\nfrom typing import *\n" + "import string\nimport re\nimport datetime\nimport collections\n" + "import heapq\nimport bisect\nimport copy\nimport math\nimport random\n" + "import statistics\nimport itertools\nimport functools\nimport operator\n" + "import io\nimport sys\nimport json\nsys.setrecursionlimit(50000)\n" +) image = modal.Image.debian_slim(python_version="3.10").pip_install("numpy") app = modal.App("joeli-lcb", image=image) @@ -121,7 +131,7 @@ def clean_if_name(code: str) -> str: code = ( ast.unparse(astree.body[:-1]) + "\n" + ast.unparse(last_block.body) # type: ignore ) - except: + except Exception: pass return code @@ -155,7 +165,7 @@ def make_function(code: str) -> str: + ast.unparse(function_ast) # type: ignore ) return main_code - except Exception as e: + except Exception: return code @@ -187,7 +197,7 @@ def call_method(method, inputs): _method.__globals__["stdout"] = sys.stdout try: return _method() - except SystemExit as e: + except SystemExit: pass finally: pass @@ -199,7 +209,7 @@ def get_function(compiled_sol, fn_name: str): # type: ignore try: assert hasattr(compiled_sol, fn_name) return getattr(compiled_sol, fn_name) - except Exception as e: + except Exception: return @@ -230,13 +240,13 @@ def compile_code(code: str, timeout: int): def convert_line_to_decimals(line: str) -> tuple[bool, list[Decimal]]: try: decimal_line = [Decimal(elem) for elem in line.split()] - except: + except Exception: return False, [] return True, decimal_line def get_stripped_lines(val: str): - ## you don't want empty lines to add empty list after splitlines! + # you don't want empty lines to add empty list after splitlines! val = val.strip() return [val_line.strip() for val_line in val.split("\n")] @@ -349,10 +359,10 @@ def grade_stdio( timeout: int, ): print("ORIGINAL CODE\n", code) - ## runtime doesn't interact well with __name__ == '__main__' + # runtime doesn't interact well with __name__ == '__main__' code = clean_if_name(code) - ## we wrap the given code inside another function + # we wrap the given code inside another function code = make_function(code) compiled_sol = compile_code(code, timeout) @@ -415,8 +425,8 @@ def grade_stdio( else: stripped_gt_out_lines = gt_out - ## WA happens in multiple circumstances - ## so cache the return to make it clean! + # WA happens in multiple circumstances + # so cache the return to make it clean! WA_send_args = { "output": truncatefn(prediction), "inputs": truncatefn(gt_inp), @@ -434,18 +444,19 @@ def grade_stdio( stripped_gt_out_line, ) in enumerate(zip(stripped_prediction_lines, stripped_gt_out_lines)): WA_send_args["error_message"] = ( - f"Wrong answer at {output_line_idx=}: {truncatefn(stripped_prediction_line)} != {truncatefn(stripped_gt_out_line)}" + f"Wrong answer at {output_line_idx=}: " + f"{truncatefn(stripped_prediction_line)} != {truncatefn(stripped_gt_out_line)}" ) - ## CASE 1: exact match + # CASE 1: exact match if stripped_prediction_line == stripped_gt_out_line: continue - ## CASE 2: element-wise comparision - ## if there are floating elements - ## use `decimal` library for good floating point comparision - ## otherwise gotcha: np.isclose(50000000000000000, 50000000000000001) = True - ## note that we should always be able to convert to decimals + # CASE 2: element-wise comparision + # if there are floating elements + # use `decimal` library for good floating point comparision + # otherwise gotcha: np.isclose(50000000000000000, 50000000000000001) = True + # note that we should always be able to convert to decimals success, decimal_prediction_line = convert_line_to_decimals( stripped_prediction_line @@ -524,7 +535,6 @@ def run_test(sample, test=None, debug=False, timeout=15): return in_outs, {"error": "no test code provided"} elif test is not None: results = [] - sol = import_string if debug: print(f"loading test code = {datetime.now().time()}")