fix lintings

This commit is contained in:
teknium 2025-12-31 21:04:40 +00:00
parent c980157a5a
commit b9dac318ee
2 changed files with 59 additions and 35 deletions

View file

@ -21,9 +21,19 @@ from types import ModuleType
from unittest.mock import mock_open, patch
import modal
import numpy as np
import_string = "from string import *\nfrom re import *\nfrom datetime import *\nfrom collections import *\nfrom heapq import *\nfrom bisect import *\nfrom copy import *\nfrom math import *\nfrom random import *\nfrom statistics import *\nfrom itertools import *\nfrom functools import *\nfrom operator import *\nfrom io import *\nfrom sys import *\nfrom json import *\nfrom builtins import *\nfrom typing import *\nimport string\nimport re\nimport datetime\nimport collections\nimport heapq\nimport bisect\nimport copy\nimport math\nimport random\nimport statistics\nimport itertools\nimport functools\nimport operator\nimport io\nimport sys\nimport json\nsys.setrecursionlimit(50000)\n"
import_string = (
"from string import *\nfrom re import *\nfrom datetime import *\n"
"from collections import *\nfrom heapq import *\nfrom bisect import *\n"
"from copy import *\nfrom math import *\nfrom random import *\n"
"from statistics import *\nfrom itertools import *\nfrom functools import *\n"
"from operator import *\nfrom io import *\nfrom sys import *\n"
"from json import *\nfrom builtins import *\nfrom typing import *\n"
"import string\nimport re\nimport datetime\nimport collections\n"
"import heapq\nimport bisect\nimport copy\nimport math\nimport random\n"
"import statistics\nimport itertools\nimport functools\nimport operator\n"
"import io\nimport sys\nimport json\nsys.setrecursionlimit(50000)\n"
)
image = modal.Image.debian_slim(python_version="3.10").pip_install("numpy")
app = modal.App("joeli-lcb", image=image)
@ -121,7 +131,7 @@ def clean_if_name(code: str) -> str:
code = (
ast.unparse(astree.body[:-1]) + "\n" + ast.unparse(last_block.body) # type: ignore
)
except:
except Exception:
pass
return code
@ -155,7 +165,7 @@ def make_function(code: str) -> str:
+ ast.unparse(function_ast) # type: ignore
)
return main_code
except Exception as e:
except Exception:
return code
@ -187,7 +197,7 @@ def call_method(method, inputs):
_method.__globals__["stdout"] = sys.stdout
try:
return _method()
except SystemExit as e:
except SystemExit:
pass
finally:
pass
@ -199,7 +209,7 @@ def get_function(compiled_sol, fn_name: str): # type: ignore
try:
assert hasattr(compiled_sol, fn_name)
return getattr(compiled_sol, fn_name)
except Exception as e:
except Exception:
return
@ -230,13 +240,13 @@ def compile_code(code: str, timeout: int):
def convert_line_to_decimals(line: str) -> tuple[bool, list[Decimal]]:
try:
decimal_line = [Decimal(elem) for elem in line.split()]
except:
except Exception:
return False, []
return True, decimal_line
def get_stripped_lines(val: str):
## you don't want empty lines to add empty list after splitlines!
# you don't want empty lines to add empty list after splitlines!
val = val.strip()
return [val_line.strip() for val_line in val.split("\n")]
@ -349,10 +359,10 @@ def grade_stdio(
timeout: int,
):
print("ORIGINAL CODE\n", code)
## runtime doesn't interact well with __name__ == '__main__'
# runtime doesn't interact well with __name__ == '__main__'
code = clean_if_name(code)
## we wrap the given code inside another function
# we wrap the given code inside another function
code = make_function(code)
compiled_sol = compile_code(code, timeout)
@ -415,8 +425,8 @@ def grade_stdio(
else:
stripped_gt_out_lines = gt_out
## WA happens in multiple circumstances
## so cache the return to make it clean!
# WA happens in multiple circumstances
# so cache the return to make it clean!
WA_send_args = {
"output": truncatefn(prediction),
"inputs": truncatefn(gt_inp),
@ -434,18 +444,19 @@ def grade_stdio(
stripped_gt_out_line,
) in enumerate(zip(stripped_prediction_lines, stripped_gt_out_lines)):
WA_send_args["error_message"] = (
f"Wrong answer at {output_line_idx=}: {truncatefn(stripped_prediction_line)} != {truncatefn(stripped_gt_out_line)}"
f"Wrong answer at {output_line_idx=}: "
f"{truncatefn(stripped_prediction_line)} != {truncatefn(stripped_gt_out_line)}"
)
## CASE 1: exact match
# CASE 1: exact match
if stripped_prediction_line == stripped_gt_out_line:
continue
## CASE 2: element-wise comparision
## if there are floating elements
## use `decimal` library for good floating point comparision
## otherwise gotcha: np.isclose(50000000000000000, 50000000000000001) = True
## note that we should always be able to convert to decimals
# CASE 2: element-wise comparision
# if there are floating elements
# use `decimal` library for good floating point comparision
# otherwise gotcha: np.isclose(50000000000000000, 50000000000000001) = True
# note that we should always be able to convert to decimals
success, decimal_prediction_line = convert_line_to_decimals(
stripped_prediction_line
@ -524,7 +535,6 @@ def run_test(sample, test=None, debug=False, timeout=15):
return in_outs, {"error": "no test code provided"}
elif test is not None:
results = []
sol = import_string
if debug:
print(f"loading test code = {datetime.now().time()}")