mirror of
https://github.com/InternLM/InternBootcamp.git
synced 2026-04-19 12:58:04 +00:00
87 lines
No EOL
3.2 KiB
Python
Executable file
87 lines
No EOL
3.2 KiB
Python
Executable file
from internbootcamp.bootcamp.base import Basebootcamp
|
|
import re
|
|
import tempfile
|
|
import subprocess
|
|
import os
|
|
from datasets import load_dataset
|
|
|
|
class KodCodebootcamp(Basebootcamp):
    """Bootcamp serving KodCode coding problems, scored by running unit tests.

    Cases are read in order from the ``PRIME-Scaling/Kodcode`` train split.
    A response is scored as follows:
    - take the first ```python block in the response;
    - write it to ``solution.py`` next to the case's test file;
    - run ``python -m unittest`` on that file in a subprocess;
    - return the fraction of tests that passed.
    """

    # Wall-clock limit, in seconds, for one unittest subprocess run.
    TEST_TIMEOUT_SECONDS = 30

    def __init__(self, **params):
        """Load the dataset and position the case cursor at row 0.

        Args:
            **params: Forwarded unchanged to ``Basebootcamp.__init__``.
        """
        super().__init__(**params)
        self.dataset = load_dataset("PRIME-Scaling/Kodcode", split="train")
        # Index of the next dataset row case_generator will serve.
        self.item_id = 0

    def case_generator(self):
        """Return the next case from the dataset, wrapping around at the end.

        Returns:
            dict with keys:
            - ``'prompt'``: the ``content`` of the last entry in the row's
              ``prompt`` list;
            - ``'code_prompt'`` and ``'test'``: both read from the row's
              ``reward_model['ground_truth']``.

        Raises:
            IndexError: if the dataset has no rows.
        """
        size = len(self.dataset)
        if size == 0:
            raise IndexError("KodCode dataset is empty")
        # Wrap around instead of failing once every row has been served.
        source_case = self.dataset[self.item_id % size]
        ground_truth = source_case['reward_model']['ground_truth']
        case = {
            'prompt': source_case['prompt'][-1]['content'],
            'code_prompt': ground_truth['code_prompt'],
            'test': ground_truth['test'],
        }
        self.item_id += 1
        return case

    @staticmethod
    def prompt_func(case) -> str:
        """Build the model prompt: ``case['prompt']``, a newline, then ``case['code_prompt']``."""
        return case['prompt'] + "\n" + case['code_prompt']

    @staticmethod
    def extract_output(output):
        """Return the stripped body of the first ```python fenced block in *output*.

        Returns ``None`` when the response contains no such block.
        """
        match = re.search(r'```python(.*?)```', output, re.DOTALL)
        if match is None:
            return None
        return match.group(1).strip()

    @classmethod
    def _verify_correction(cls, solution, identity) -> float:
        """Run the case's tests against *solution* and return the pass rate.

        Args:
            solution: Python source from ``extract_output``. It is ``None``
                when the response had no ```python block.
            identity: A case dict from ``case_generator``. Only ``'test'`` is
                read.

        Returns:
            Fraction of tests that passed, in ``[0.0, 1.0]``. Returns ``0.0``
            in any of these cases:
            - there is no solution;
            - the run exceeds ``TEST_TIMEOUT_SECONDS``;
            - unittest reports no tests run.
        """
        import sys  # only needed here, for sys.executable

        if solution is None:
            # No code block in the response: nothing to run, score 0.
            return 0.0

        with tempfile.TemporaryDirectory() as temp_dir:
            # 1. The candidate code becomes solution.py so the tests can import it.
            with open(os.path.join(temp_dir, "solution.py"), "w", encoding="utf-8") as f:
                f.write(solution)

            # 2. The case's test code becomes test_task.py.
            # NOTE(review): `python -m unittest` only collects unittest.TestCase
            # subclasses. If identity['test'] holds pytest-style module-level
            # test functions, unittest reports "Ran 0 tests" and the score is
            # 0.0. Confirm the dataset's test format.
            with open(os.path.join(temp_dir, "test_task.py"), "w", encoding="utf-8") as f:
                f.write(identity['test'])

            # 3. Run the tests with the interpreter executing this process.
            # A bare "python" on PATH may be a different interpreter, or absent.
            try:
                completed = subprocess.run(
                    [sys.executable, "-m", "unittest", "test_task.py"],
                    cwd=temp_dir,
                    stdout=subprocess.PIPE,
                    stderr=subprocess.PIPE,
                    text=True,
                    errors="replace",  # untrusted code may emit undecodable bytes
                    timeout=cls.TEST_TIMEOUT_SECONDS,
                )
            except subprocess.TimeoutExpired:
                # subprocess.run has already killed and reaped the child.
                return 0.0

        # 4. unittest writes its "Ran N tests" / OK / FAILED summary to stderr.
        return cls._unittest_pass_rate(completed.stderr)

    @staticmethod
    def _unittest_pass_rate(report):
        """Return passed/total parsed from unittest's stderr *report*.

        Returns 0.0 if no ``Ran N test(s)`` line is found or N is 0.
        Only the text after the last ``Ran`` line is searched for the
        ``failures=`` / ``errors=`` counts. This stops matching text printed
        earlier by the code under test from skewing the score. Skipped tests
        count as passed.
        """
        runs = list(re.finditer(r'Ran (\d+) tests?', report))
        if not runs:
            return 0.0
        last_run = runs[-1]
        total = int(last_run.group(1))
        if total == 0:
            return 0.0
        summary = report[last_run.end():]
        # The lookbehind stops "expected failures=N" from being counted as
        # failures; unittest treats expected failures as successes.
        failed = re.search(r'(?<!expected )failures=(\d+)', summary)
        errored = re.search(r'errors=(\d+)', summary)
        not_passed = (int(failed.group(1)) if failed else 0) + (
            int(errored.group(1)) if errored else 0
        )
        return max(total - not_passed, 0) / total