InternBootcamp/internbootcamp/bootcamp/kodcode/kodcode.py
2025-05-23 15:27:15 +08:00

87 lines
No EOL
3.2 KiB
Python

from internbootcamp.bootcamp.base import Basebootcamp
import re
import tempfile
import subprocess
import os
from datasets import load_dataset
class KodCodebootcamp(Basebootcamp):
def __init__(self, **params):
super().__init__(**params)
self.dataset = load_dataset("PRIME-Scaling/Kodcode", split="train")
self.item_id = 0
# no case generator for instruction following
def case_generator(self):
source_case = self.dataset[self.item_id]
case = {
'prompt': source_case['prompt'][-1]['content'],
'code_prompt': source_case['reward_model']['ground_truth']['code_prompt'],
'test': source_case['reward_model']['ground_truth']['test'],
}
self.item_id += 1
return case
@staticmethod
def prompt_func(case) -> str:
return case['prompt'] + "\n" + case['code_prompt']
@staticmethod
def extract_output(output):
match = re.search(r'```python(.*?)```', output, re.DOTALL)
if match:
solution = match.group(1).strip()
# print(f"solution: {solution}")
return solution
else:
return None
@classmethod
def _verify_correction(cls, solution, identity) -> bool:
# task = identity['task']
test_code = identity['test']
# print(f"test_code: {test_code}")
timeout = 30
with tempfile.TemporaryDirectory() as temp_dir:
# 1. 生成 task_func 代码文件
task_file = os.path.join(temp_dir, "solution.py")
with open(task_file, "w", encoding="utf-8") as f:
f.write(solution)
# 2. 生成 unittest 测试文件
# test_code = "from task import task_func\n" + test_code
test_file = os.path.join(temp_dir, "test_task.py")
with open(test_file, "w", encoding="utf-8") as f:
f.write(test_code)
# 3. 运行 unittest
process = subprocess.Popen(
["python", "-m", "unittest", "test_task.py"],
cwd=temp_dir,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
text=True
)
try:
stdout, stderr = process.communicate(timeout=timeout)
except subprocess.TimeoutExpired:
process.kill()
process.wait()
return 0 # 进程超时,返回 0 分
process.stdout.close() # 关闭管道,释放资源
process.stderr.close()
# 4. 解析测试结果
match_ran = re.search(r'Ran (\d+) tests?', stderr)
match_failed = re.search(r'failures=(\d+)', stderr)
match_error = re.search(r'errors=(\d+)', stderr)
total_tests = int(match_ran.group(1)) if match_ran else 0
failed_tests = int(match_failed.group(1)) if match_failed else 0
error_tests = int(match_error.group(1)) if match_error else 0
passed_tests = total_tests - failed_tests - error_tests
pass_rate = float(passed_tests) / float(total_tests) if total_tests > 0 else 0.0
return pass_rate