mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-19 12:57:58 +00:00
feat: add code_debug community environment
This commit is contained in:
parent
c421582b6f
commit
590e8a1ef2
4 changed files with 930 additions and 0 deletions
81
environments/community/code_debug_env/README.md
Normal file
81
environments/community/code_debug_env/README.md
Normal file
|
|
@ -0,0 +1,81 @@
|
|||
# Code Debug Environment
|
||||
|
||||
An Atropos RL environment for training LLMs to debug and fix buggy Python code.
|
||||
|
||||
## Overview
|
||||
|
||||
This environment uses the [HumanEvalFix](https://huggingface.co/datasets/bigcode/humanevalfix-python) dataset, which contains 164 buggy Python functions with associated test suites. The model receives a buggy function and must output the corrected version inside `\boxed{}`. Scoring is done by executing the fixed code against the original test cases.
|
||||
|
||||
## Architecture
|
||||
|
||||
```
|
||||
code_debug_env.py # Main env (extends BaseEnv)
|
||||
code_executor.py # Safe subprocess execution with timeout
|
||||
test_code_debug.py # Unit tests
|
||||
README.md # This file
|
||||
```
|
||||
|
||||
## Reward Design
|
||||
|
||||
| Outcome | Score | Description |
|
||||
|---------|-------|-------------|
|
||||
| All tests pass | **1.0** | Perfect fix |
|
||||
| Partial improvement | **-0.5 to 0.9** | More tests pass than buggy version |
|
||||
| No improvement | **-0.5** | Code runs but doesn't fix anything |
|
||||
| Compilation error / regression | **-1.0** | Fix is worse than the original |
|
||||
|
||||
When all rollouts in a group score 1.0, a **length penalty** is applied to encourage concise solutions (same pattern as `sql_query_env`).
|
||||
|
||||
## Setup
|
||||
|
||||
```bash
|
||||
# Install dependencies (datasets is the only extra)
|
||||
pip install datasets
|
||||
|
||||
# Run tests
|
||||
cd environments/community/code_debug_env
|
||||
python -m pytest test_code_debug.py -v
|
||||
```
|
||||
|
||||
## Usage
|
||||
|
||||
```bash
|
||||
# Process mode (offline data generation)
|
||||
python code_debug_env.py process \
|
||||
--env.data_path_to_save_groups data/code_debug.jsonl \
|
||||
--env.group_size 8 \
|
||||
--openai.base_url http://localhost:8000/v1 \
|
||||
--openai.model_name "NousResearch/DeepHermes-3-Llama-3-3B-Preview"
|
||||
|
||||
# Serve mode (online RL training)
|
||||
python code_debug_env.py serve \
|
||||
--openai.base_url http://localhost:9001/v1 \
|
||||
--openai.model_name "NousResearch/DeepHermes-3-Llama-3-3B-Preview"
|
||||
|
||||
# Evaluate mode
|
||||
python code_debug_env.py evaluate \
|
||||
--openai.base_url http://localhost:8000/v1 \
|
||||
--openai.model_name "NousResearch/DeepHermes-3-Llama-3-3B-Preview"
|
||||
```
|
||||
|
||||
## WandB Metrics
|
||||
|
||||
| Metric | Description |
|
||||
|--------|-------------|
|
||||
| `train/percent_correct` | Fraction of rollouts that pass all tests |
|
||||
| `train/avg_score` | Average reward across rollouts |
|
||||
| `train/partial_fix_rate` | Fraction of rollouts that partially fix the code |
|
||||
| `eval/percent_correct` | Eval set accuracy |
|
||||
|
||||
## Dataset
|
||||
|
||||
- **Source**: [bigcode/humanevalfix-python](https://huggingface.co/datasets/bigcode/humanevalfix-python)
|
||||
- **License**: Apache 2.0
|
||||
- **Size**: 164 problems
|
||||
- **Split**: 80% train / 20% test
|
||||
|
||||
## Compute Footprint
|
||||
|
||||
- **RAM**: < 1 GB (dataset is small, execution is in subprocess)
|
||||
- **CPU**: < 5s per verification (subprocess with 10s timeout)
|
||||
- **GPU**: Only needed for the inference server
|
||||
457
environments/community/code_debug_env/code_debug_env.py
Normal file
457
environments/community/code_debug_env/code_debug_env.py
Normal file
|
|
@ -0,0 +1,457 @@
|
|||
"""
|
||||
Code Debug Environment for Atropos
|
||||
|
||||
Trains LLMs to debug and fix buggy Python functions.
|
||||
Uses the HumanEvalFix dataset with execution-based verification
|
||||
against ground-truth test cases.
|
||||
|
||||
Environment pattern follows sql_query_env for consistency.
|
||||
"""
|
||||
|
||||
import random
|
||||
from typing import Any, Dict, List, Optional, Tuple, TypedDict, Union
|
||||
|
||||
from code_executor import (
|
||||
count_test_results,
|
||||
execute_code_with_tests,
|
||||
extract_boxed_code,
|
||||
)
|
||||
from tqdm.asyncio import tqdm_asyncio
|
||||
|
||||
from atroposlib.envs.base import (
|
||||
APIServerConfig,
|
||||
BaseEnv,
|
||||
BaseEnvConfig,
|
||||
ScoredDataGroup,
|
||||
)
|
||||
from atroposlib.type_definitions import Item
|
||||
|
||||
# System prompt following established Atropos patterns
# Part 1: generic deep-thinking preamble; the model is told to wrap its
# deliberation in <think> tags before answering.
SYSTEM_PROMPT = (
    "You are a deep thinking AI, you may use extremely long chains of thought "
    "to deeply consider the problem and deliberate with yourself via systematic "
    "reasoning processes to help come to a correct solution prior to answering. "
    "You should enclose your thoughts and internal monologue inside <think> </think> "
    "tags, and then provide your solution or response to the problem.\n\n"
)

# Part 2: task-specific debugging instructions. The \boxed{...} wrapper is
# what extract_boxed_code() parses out of the completion, so the format shown
# here must stay in sync with that parser.
SYSTEM_PROMPT += """You are an expert Python debugger. Given a buggy Python function \
and its test cases, identify the bug and provide the corrected function.

You are allocated a maximum of 2048 tokens, please strive to use less.

Provide your corrected function inside \\boxed{} like this:
\\boxed{def function_name(args):
    # your corrected code here
    return result
}

Important:
- Keep the same function signature
- Only fix the bug, don't rewrite the function from scratch unless necessary
- Ensure the function passes all provided test cases
- Do not include test cases or the check function in your answer

End your answer with \\boxed{your corrected function here}"""
|
||||
|
||||
|
||||
class CodeDebugItem(TypedDict):
    """Type definition for a HumanEvalFix dataset item."""

    # Unique problem identifier from the dataset.
    task_id: str
    # Function declaration/signature (loaded from the row's "declaration").
    prompt: str
    # Intentionally broken implementation body.
    buggy_solution: str
    # Reference correct implementation body.
    canonical_solution: str
    # Test source defining a `check(candidate)` function.
    test: str
    # Name of the function under test.
    entry_point: str


def format_debug_prompt(item: CodeDebugItem) -> str:
    """Build the user-facing prompt presenting the buggy function.

    The exact test assertions are deliberately not shown, so the model must
    reason about the bug rather than pattern-match expected outputs.
    """
    full_source = item["prompt"] + item["buggy_solution"]
    segments = [
        "Here is a buggy Python function:\n\n",
        f"```python\n{full_source}```\n\n",
        f"The function `{item['entry_point']}` has a bug. ",
        "It fails its test cases.\n\n",
        "Please identify the bug, fix it, and provide the corrected ",
        "function inside \\boxed{}.",
    ]
    return "".join(segments)
|
||||
|
||||
|
||||
class CodeDebugEnv(BaseEnv):
|
||||
"""
|
||||
Environment for training LLMs to debug Python code.
|
||||
|
||||
Uses the HumanEvalFix dataset. The model receives a buggy function
|
||||
and must output the corrected version. Scoring is done by executing
|
||||
the fixed code against the original test suite.
|
||||
"""
|
||||
|
||||
name = "code_debug"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: BaseEnvConfig,
|
||||
server_configs: List[APIServerConfig],
|
||||
slurm=True,
|
||||
testing=False,
|
||||
):
|
||||
super().__init__(config, server_configs, slurm, testing)
|
||||
self.percent_correct_buffer = list()
|
||||
self.eval_metrics = list()
|
||||
self.partial_fix_buffer = list()
|
||||
self.raw_score_buffer = list()
|
||||
|
||||
@classmethod
|
||||
def config_init(cls) -> Tuple[BaseEnvConfig, List[APIServerConfig]]:
|
||||
"""Initialize default configuration for the environment."""
|
||||
env_config = BaseEnvConfig(
|
||||
tokenizer_name="NousResearch/DeepHermes-3-Llama-3-3B-Preview",
|
||||
group_size=8,
|
||||
use_wandb=True,
|
||||
rollout_server_url="http://localhost:8000",
|
||||
total_steps=1000,
|
||||
batch_size=12,
|
||||
steps_per_eval=100,
|
||||
max_token_length=2048,
|
||||
wandb_name="code_debug",
|
||||
)
|
||||
server_configs = [
|
||||
APIServerConfig(
|
||||
model_name="NousResearch/DeepHermes-3-Llama-3-3B-Preview",
|
||||
base_url="http://localhost:9001/v1",
|
||||
api_key="x",
|
||||
num_requests_for_eval=256,
|
||||
),
|
||||
]
|
||||
return env_config, server_configs
|
||||
|
||||
async def wandb_log(self, wandb_metrics: Optional[Dict] = None):
|
||||
"""Log custom metrics to WandB."""
|
||||
if wandb_metrics is None:
|
||||
wandb_metrics = {}
|
||||
|
||||
# Log percent fully correct (all tests pass)
|
||||
if self.percent_correct_buffer:
|
||||
wandb_metrics["train/percent_correct"] = sum(
|
||||
self.percent_correct_buffer
|
||||
) / len(self.percent_correct_buffer)
|
||||
|
||||
# Log average raw score (includes partial credit)
|
||||
if self.raw_score_buffer:
|
||||
wandb_metrics["train/avg_score"] = sum(self.raw_score_buffer) / len(
|
||||
self.raw_score_buffer
|
||||
)
|
||||
|
||||
# Log partial fix rate (code runs but doesn't pass all tests)
|
||||
if self.partial_fix_buffer:
|
||||
wandb_metrics["train/partial_fix_rate"] = sum(
|
||||
self.partial_fix_buffer
|
||||
) / len(self.partial_fix_buffer)
|
||||
|
||||
self.percent_correct_buffer = list()
|
||||
self.raw_score_buffer = list()
|
||||
self.partial_fix_buffer = list()
|
||||
|
||||
for item in self.eval_metrics:
|
||||
wandb_metrics[item[0]] = item[1]
|
||||
self.eval_metrics = list()
|
||||
|
||||
await super().wandb_log(wandb_metrics)
|
||||
|
||||
async def setup(self):
|
||||
"""Load the HumanEvalFix dataset and prepare train/test splits."""
|
||||
from datasets import load_dataset
|
||||
|
||||
print("Loading HumanEvalFix dataset...")
|
||||
dataset = load_dataset("bigcode/humanevalfix-python", split="test")
|
||||
|
||||
all_items: List[CodeDebugItem] = []
|
||||
for row in dataset:
|
||||
all_items.append(
|
||||
{
|
||||
"task_id": row["task_id"],
|
||||
"prompt": row["declaration"],
|
||||
"buggy_solution": row["buggy_solution"],
|
||||
"canonical_solution": row["canonical_solution"],
|
||||
"test": row["test"],
|
||||
"entry_point": row["entry_point"],
|
||||
}
|
||||
)
|
||||
|
||||
print(f"Loaded {len(all_items)} problems")
|
||||
|
||||
# Verify a few items actually work with canonical solutions
|
||||
verified = 0
|
||||
for item in all_items[:10]:
|
||||
code = item["prompt"] + item["canonical_solution"]
|
||||
passed, _ = execute_code_with_tests(
|
||||
code, item["test"], item["entry_point"]
|
||||
)
|
||||
if passed:
|
||||
verified += 1
|
||||
print(f"Verified {verified}/10 canonical solutions execute correctly")
|
||||
|
||||
# Split 80/20 train/test
|
||||
random.shuffle(all_items)
|
||||
split_idx = int(len(all_items) * 0.8)
|
||||
self.train = all_items[:split_idx]
|
||||
self.test = all_items[split_idx:]
|
||||
|
||||
print(f"Train: {len(self.train)}, Test: {len(self.test)}")
|
||||
self.iter = 0
|
||||
|
||||
def save_checkpoint(self, step, data=None):
|
||||
"""Save checkpoint with iteration state."""
|
||||
if data is None:
|
||||
data = {}
|
||||
data["iter"] = self.iter
|
||||
super().save_checkpoint(step, data)
|
||||
|
||||
def _score_fix(self, generated_code: str, item: CodeDebugItem) -> Tuple[float, bool]:
|
||||
"""
|
||||
Score a generated fix by execution against test cases.
|
||||
|
||||
Returns:
|
||||
Tuple of (score, is_partial_fix):
|
||||
- score: 1.0 if all tests pass, proportional for partial,
|
||||
-1.0 if no improvement or compilation error
|
||||
- is_partial_fix: True if some but not all tests pass
|
||||
"""
|
||||
if not generated_code:
|
||||
return -1.0, False
|
||||
|
||||
# Run the fixed code against tests
|
||||
all_passed, error = execute_code_with_tests(
|
||||
generated_code, item["test"], item["entry_point"]
|
||||
)
|
||||
|
||||
if all_passed:
|
||||
return 1.0, False
|
||||
|
||||
# Check for partial credit — how many tests pass?
|
||||
passed, total = count_test_results(
|
||||
generated_code, item["test"], item["entry_point"]
|
||||
)
|
||||
|
||||
if total == 0:
|
||||
return -1.0, False
|
||||
|
||||
# Also check how the buggy code does
|
||||
buggy_code = item["prompt"] + item["buggy_solution"]
|
||||
buggy_passed, buggy_total = count_test_results(
|
||||
buggy_code, item["test"], item["entry_point"]
|
||||
)
|
||||
|
||||
# Score based on improvement over buggy code
|
||||
if passed > buggy_passed:
|
||||
# Partial improvement: scale between -0.5 and 0.9
|
||||
improvement_ratio = (passed - buggy_passed) / max(1, total - buggy_passed)
|
||||
score = -0.5 + 1.4 * improvement_ratio
|
||||
return score, True
|
||||
elif passed == buggy_passed and passed > 0:
|
||||
# No improvement but code at least runs
|
||||
return -0.5, True
|
||||
else:
|
||||
# Made things worse or code doesn't compile
|
||||
return -1.0, False
|
||||
|
||||
async def rollout_and_score_eval(self, item: CodeDebugItem) -> dict:
|
||||
"""Rollout and score a single evaluation item."""
|
||||
user_content = format_debug_prompt(item)
|
||||
|
||||
async with self.server.managed_server(tokenizer=self.tokenizer) as managed:
|
||||
completion = await managed.chat_completion(
|
||||
messages=[
|
||||
{"role": "system", "content": SYSTEM_PROMPT},
|
||||
{"role": "user", "content": user_content},
|
||||
],
|
||||
n=1,
|
||||
max_tokens=self.config.max_token_length,
|
||||
temperature=0.6,
|
||||
)
|
||||
response_content = completion.choices[0].message.content
|
||||
|
||||
# Extract and score generated fix
|
||||
generated_code = extract_boxed_code(response_content)
|
||||
score, is_partial = self._score_fix(generated_code, item)
|
||||
correct = score == 1.0
|
||||
|
||||
sample = {
|
||||
"messages": [
|
||||
{"role": "system", "content": SYSTEM_PROMPT},
|
||||
{"role": "user", "content": user_content},
|
||||
{"role": "assistant", "content": response_content},
|
||||
],
|
||||
"question": f"Fix bug in {item['entry_point']}",
|
||||
"gold_answer": item["canonical_solution"],
|
||||
"model_parsed": generated_code or "(no code extracted)",
|
||||
"score": 1 if correct else 0,
|
||||
"correct": correct,
|
||||
"finish_reason": completion.choices[0].finish_reason,
|
||||
}
|
||||
|
||||
return {"score": 1 if correct else 0, "sample": sample}
|
||||
|
||||
async def evaluate(self, *args, **kwargs):
|
||||
"""Run evaluation on test set."""
|
||||
import time
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
eval_tasks = []
|
||||
for item in self.test[:100]:
|
||||
eval_tasks.append(self.rollout_and_score_eval(item))
|
||||
results = await tqdm_asyncio.gather(*eval_tasks)
|
||||
|
||||
scores = [result["score"] for result in results]
|
||||
samples = [result["sample"] for result in results]
|
||||
|
||||
percent_correct = sum(scores) / len(scores) if scores else 0
|
||||
|
||||
end_time = time.time()
|
||||
|
||||
self.eval_metrics.append(("eval/percent_correct", percent_correct))
|
||||
|
||||
eval_metrics = {
|
||||
"eval/percent_correct": percent_correct,
|
||||
}
|
||||
|
||||
await self.evaluate_log(
|
||||
metrics=eval_metrics,
|
||||
samples=samples,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
generation_parameters={
|
||||
"temperature": 0.6,
|
||||
"max_tokens": self.config.max_token_length,
|
||||
},
|
||||
)
|
||||
|
||||
async def collect_trajectories(
|
||||
self, item: CodeDebugItem
|
||||
) -> Tuple[ScoredDataGroup, list[Item]]:
|
||||
"""Generate code fixes for a buggy function."""
|
||||
user_content = format_debug_prompt(item)
|
||||
user_message = {"role": "user", "content": user_content}
|
||||
|
||||
async with self.server.managed_server(tokenizer=self.tokenizer) as managed:
|
||||
chat_completions = await managed.chat_completion(
|
||||
messages=[
|
||||
{"role": "system", "content": SYSTEM_PROMPT},
|
||||
user_message,
|
||||
],
|
||||
n=self.config.group_size,
|
||||
max_tokens=self.config.max_token_length,
|
||||
temperature=1.0,
|
||||
)
|
||||
|
||||
state = managed.get_state()
|
||||
nodes = state["nodes"]
|
||||
|
||||
to_score = list()
|
||||
to_backlog = list()
|
||||
|
||||
for i, chat_completion in enumerate(chat_completions.choices):
|
||||
messages = (
|
||||
{"role": "system", "content": SYSTEM_PROMPT},
|
||||
user_message,
|
||||
{"role": "assistant", "content": chat_completion.message.content},
|
||||
)
|
||||
to_score.append(
|
||||
{
|
||||
"messages": messages,
|
||||
"item": item,
|
||||
"finish_reason": chat_completion.finish_reason,
|
||||
"tokens": nodes[i].tokens,
|
||||
"masks": nodes[i].masked_tokens,
|
||||
"logprobs": nodes[i].logprobs,
|
||||
}
|
||||
)
|
||||
|
||||
to_postprocess = await self.score(to_score)
|
||||
return to_postprocess, to_backlog
|
||||
|
||||
    async def score(
        self, rollout_group_data
    ) -> Union[Optional[ScoredDataGroup], List[Optional[ScoredDataGroup]]]:
        """Score generated code fixes by execution against test cases.

        Returns None when the group carries no training signal: GRPO-style
        advantages need within-group score variance.
        """
        scores = ScoredDataGroup()
        scores["tokens"] = list()
        scores["masks"] = list()
        scores["scores"] = list()
        scores["inference_logprobs"] = list()

        # All rollouts in a group share the same dataset item.
        item = rollout_group_data[0]["item"]

        random.shuffle(rollout_group_data)

        for rollout in rollout_group_data:
            response_content = rollout["messages"][-1]["content"]

            # Extract fixed code from \boxed{}
            generated_code = extract_boxed_code(response_content)

            # Score by execution
            reward, is_partial = self._score_fix(generated_code, item)
            self.partial_fix_buffer.append(1 if is_partial else 0)

            tokens = rollout["tokens"]
            masks = rollout["masks"]
            logprobs = rollout["logprobs"]

            # Remove obviously bad examples (too short responses)
            if len([1 for m in masks if m != -100]) < 10:
                continue

            scores["tokens"].append(tokens)
            scores["masks"].append(masks)
            scores["inference_logprobs"].append(logprobs)
            scores["scores"].append(reward)

            if len(scores["tokens"]) >= self.config.group_size:
                break

        for score in scores["scores"]:
            self.percent_correct_buffer.append(1.0 if score >= 1.0 else 0.0)
            # NOTE(review): negative rewards are clamped to 0 here, so
            # train/avg_score only reflects non-negative credit even though
            # training uses the raw (possibly negative) reward — confirm
            # this is the intended metric semantics.
            self.raw_score_buffer.append(max(score, 0.0))

        # Apply length penalty when all scores are perfect
        if scores["scores"] and all(score == 1.0 for score in scores["scores"]):
            token_lengths = [len(t) for t in scores["tokens"]]
            if max(token_lengths) == 0:
                return None

            max_allowed_length = self.config.max_token_length
            # Penalty only kicks in past 50% of the token budget.
            length_threshold = max_allowed_length * 0.5

            scores["scores"] = []
            for length in token_lengths:
                if length <= length_threshold:
                    scores["scores"].append(1.0)
                else:
                    # Linear decay from 1.0 to 0.0 across the second half
                    # of the budget.
                    pct = (length - length_threshold) / (
                        max_allowed_length - length_threshold
                    )
                    pct = min(pct, 1.0)
                    scores["scores"].append(1.0 - pct)

        # If all scores are same, return None (GRPO needs variance)
        if scores["scores"] and all(
            scores["scores"][0] == s for s in scores["scores"]
        ):
            return None

        return scores
|
||||
|
||||
async def get_next_item(self) -> CodeDebugItem:
|
||||
"""Get the next training item."""
|
||||
next_item = self.train[self.iter % len(self.train)]
|
||||
self.iter += 1
|
||||
return next_item
|
||||
|
||||
|
||||
# CLI entry point: `python code_debug_env.py {serve,process,evaluate} ...`
if __name__ == "__main__":
    CodeDebugEnv.cli()
|
||||
213
environments/community/code_debug_env/code_executor.py
Normal file
213
environments/community/code_debug_env/code_executor.py
Normal file
|
|
@ -0,0 +1,213 @@
|
|||
"""
|
||||
Safe code execution utilities for the Code Debug environment.
|
||||
|
||||
Runs generated code in isolated subprocess with timeout and resource limits.
|
||||
"""
|
||||
|
||||
import os
import subprocess
import sys
import tempfile
from typing import Optional, Tuple
|
||||
|
||||
|
||||
def execute_code_with_tests(
    code: str,
    test_code: str,
    entry_point: str,
    timeout: int = 10,
) -> Tuple[bool, str]:
    """
    Execute code with test cases in an isolated subprocess.

    WARNING: this runs model-generated code. The subprocess and timeout
    limit runaway execution, but this is not a security sandbox.

    Args:
        code: The function implementation to test.
        test_code: Test code containing a `check(candidate)` function.
        entry_point: The function name to pass to `check()`.
        timeout: Maximum execution time in seconds.

    Returns:
        Tuple of (all_tests_passed, error_message).
    """
    # Assemble a standalone script: implementation, tests, then the call.
    full_code = code + "\n\n" + test_code + f"\n\ncheck({entry_point})\n"

    tmp_path = None
    try:
        with tempfile.NamedTemporaryFile(
            mode="w", suffix=".py", delete=False, prefix="code_debug_"
        ) as f:
            f.write(full_code)
            tmp_path = f.name

        # Use the current interpreter rather than whatever "python3" is on
        # PATH: more robust (e.g. Windows, bare venvs) and guarantees the
        # tests run under the same Python as the harness.
        result = subprocess.run(
            [sys.executable, tmp_path],
            capture_output=True,
            text=True,
            timeout=timeout,
        )
        if result.returncode == 0:
            return True, ""
        error = result.stderr.strip()
        # Keep only the tail so long tracebacks stay readable in logs.
        if len(error) > 500:
            error = error[-500:]
        return False, error

    except subprocess.TimeoutExpired:
        return False, "Execution timed out"
    except Exception as e:
        # Report harness failures as a failed run, never crash the caller.
        return False, f"{type(e).__name__}: {str(e)[:200]}"
    finally:
        # Always remove the temp script, even on timeout/error.
        if tmp_path and os.path.exists(tmp_path):
            os.unlink(tmp_path)
|
||||
|
||||
|
||||
def extract_boxed_code(text: str) -> Optional[str]:
    """
    Extract code from \\boxed{...} format, handling nested braces.

    The LAST \\boxed{} occurrence wins, since models often discuss earlier
    attempts before committing to a final answer.

    Args:
        text: The model's response text.

    Returns:
        The extracted code string (stripped), or None when no complete
        \\boxed{...} span exists.
    """
    marker = "\\boxed{"
    marker_at = text.rfind(marker)
    if marker_at == -1:
        return None

    body_start = marker_at + len(marker)
    depth = 1
    # Scan forward tracking brace depth so dict/set literals inside the
    # answer do not terminate the match early.
    for pos in range(body_start, len(text)):
        ch = text[pos]
        if ch == "{":
            depth += 1
        elif ch == "}":
            depth -= 1
            if depth == 0:
                return text[body_start:pos].strip()
    # Ran off the end with unbalanced braces: treat as no answer.
    return None
|
||||
|
||||
|
||||
def count_test_results(
    code: str,
    test_code: str,
    entry_point: str,
    timeout: int = 10,
) -> Tuple[int, int]:
    """
    Count how many individual assertions pass vs fail.

    Wraps each assertion in a try/except to count partial success.

    Args:
        code: The function implementation to test.
        test_code: Test code containing a `check(candidate)` function.
        entry_point: The function name to pass.
        timeout: Maximum execution time in seconds.

    Returns:
        Tuple of (passed_count, total_count).
    """
    # Build a counting test harness
    # NOTE(review): test_code is spliced into a triple-quoted literal below;
    # a test body containing ''' or ending in a backslash would break the
    # generated script — confirm HumanEvalFix tests never do.
    # NOTE(review): assertions are re-run line-by-line on failure, so
    # multi-line asserts or asserts depending on earlier statements in
    # check() will be miscounted — acceptable as a partial-credit heuristic.
    counter_code = f"""
import sys

{code}

_passed = 0
_total = 0

def _counting_check(candidate):
    global _passed, _total
    import types

    # Get the original check function's code
    _orig_check_src = '''{test_code}'''
    _ns = {{'__builtins__': __builtins__}}
    exec(_orig_check_src, _ns)
    _orig_check = _ns.get('check')

    if _orig_check is None:
        print("0/0")
        return

    # Get the source of check and count assertions
    import inspect
    try:
        src = inspect.getsource(_orig_check)
    except (TypeError, OSError):
        # Can't inspect — just run it
        try:
            _orig_check(candidate)
            print("1/1")
        except Exception:
            print("0/1")
        return

    # Count 'assert' lines
    assert_lines = [l.strip() for l in src.split('\\n') if l.strip().startswith('assert')]
    _total = max(len(assert_lines), 1)

    # Run the full check — if it passes, all assertions passed
    try:
        _orig_check(candidate)
        _passed = _total
    except AssertionError:
        # Some failed — try to count
        _passed = 0
        for line in assert_lines:
            try:
                exec(line, {{'__builtins__': __builtins__, '{entry_point}': candidate, 'candidate': candidate}})
                _passed += 1
            except Exception:
                pass
    except Exception:
        _passed = 0

    print(f"{{_passed}}/{{_total}}")

_counting_check({entry_point})
"""

    tmp_path = None
    try:
        with tempfile.NamedTemporaryFile(
            mode="w", suffix=".py", delete=False, prefix="code_debug_count_"
        ) as f:
            f.write(counter_code)
            tmp_path = f.name

        result = subprocess.run(
            ["python3", tmp_path],
            capture_output=True,
            text=True,
            timeout=timeout,
        )

        # The harness prints a single "passed/total" line on stdout.
        stdout = result.stdout.strip()
        if "/" in stdout:
            parts = stdout.split("/")
            try:
                return int(parts[0]), int(parts[1])
            except (ValueError, IndexError):
                pass

        # Fallback: if execution succeeded, assume all passed
        if result.returncode == 0:
            return 1, 1
        return 0, 1

    except subprocess.TimeoutExpired:
        return 0, 1
    except Exception:
        return 0, 1
    finally:
        # Always clean up the generated harness script.
        if tmp_path and os.path.exists(tmp_path):
            os.unlink(tmp_path)
|
||||
179
environments/community/code_debug_env/test_code_debug.py
Normal file
179
environments/community/code_debug_env/test_code_debug.py
Normal file
|
|
@ -0,0 +1,179 @@
|
|||
"""
|
||||
Unit tests for the Code Debug environment.
|
||||
|
||||
Tests the code execution, scoring, and extraction utilities
|
||||
without requiring a running inference server.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
from code_executor import count_test_results, execute_code_with_tests, extract_boxed_code
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Tests for extract_boxed_code
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class TestExtractBoxedCode:
    """Unit tests for extract_boxed_code()."""

    def test_simple_extraction(self):
        # A single well-formed \boxed{} yields its contents, stripped.
        text = r"""Here is the fix:
\boxed{def add(a, b):
    return a + b
}"""
        result = extract_boxed_code(text)
        assert result is not None
        assert "def add(a, b):" in result
        assert "return a + b" in result

    def test_nested_braces(self):
        # Braces inside the answer (dict/set literals) must not end the match.
        text = r"""\boxed{def foo(x):
    if x > 0:
        return {x: x**2}
    return {}
}"""
        result = extract_boxed_code(text)
        assert result is not None
        assert "return {x: x**2}" in result
        assert "return {}" in result

    def test_no_boxed(self):
        # No \boxed{} at all → None, not an empty string.
        text = "Just some text without any boxed content"
        result = extract_boxed_code(text)
        assert result is None

    def test_last_boxed_used(self):
        # When the model revises itself, the final \boxed{} is the answer.
        text = r"""First attempt: \boxed{def bad(): pass}
Actually, let me reconsider: \boxed{def good():
    return 42
}"""
        result = extract_boxed_code(text)
        assert result is not None
        assert "def good():" in result
        assert "return 42" in result

    def test_empty_boxed(self):
        # An empty box extracts to "", distinct from the no-box None case.
        text = r"\boxed{}"
        result = extract_boxed_code(text)
        assert result == ""
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Tests for execute_code_with_tests
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class TestExecuteCodeWithTests:
    """Integration tests for the subprocess-based execution harness."""

    def test_correct_code_passes(self):
        # A correct implementation exits 0 with an empty error message.
        code = "def add(a, b):\n    return a + b\n"
        test_code = """def check(candidate):
    assert candidate(1, 2) == 3
    assert candidate(-1, 1) == 0
    assert candidate(0, 0) == 0
"""
        passed, error = execute_code_with_tests(code, test_code, "add")
        assert passed is True
        assert error == ""

    def test_buggy_code_fails(self):
        # A wrong implementation fails with a non-empty stderr message.
        code = "def add(a, b):\n    return a - b\n"
        test_code = """def check(candidate):
    assert candidate(1, 2) == 3
    assert candidate(-1, 1) == 0
"""
        passed, error = execute_code_with_tests(code, test_code, "add")
        assert passed is False
        assert error != ""

    def test_syntax_error(self):
        # Uncompilable code is reported as a failure, not an exception.
        code = "def add(a, b)\n    return a + b\n"  # Missing colon
        test_code = """def check(candidate):
    assert candidate(1, 2) == 3
"""
        passed, error = execute_code_with_tests(code, test_code, "add")
        assert passed is False

    def test_infinite_loop_timeout(self):
        # Runaway code must be killed at the timeout boundary.
        code = "def loop(x):\n    while True: pass\n"
        test_code = """def check(candidate):
    assert candidate(1) is None
"""
        passed, error = execute_code_with_tests(code, test_code, "loop", timeout=2)
        assert passed is False
        assert "timed out" in error.lower() or "Timeout" in error

    def test_runtime_error(self):
        # Runtime exceptions (here ZeroDivisionError) surface as failures.
        code = "def divide(a, b):\n    return a / b\n"
        test_code = """def check(candidate):
    assert candidate(1, 0) == 0
"""
        passed, error = execute_code_with_tests(code, test_code, "divide")
        assert passed is False
||||
|
||||
|
||||
# ============================================================================
|
||||
# Tests for count_test_results
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class TestCountTestResults:
    """Tests for the per-assertion counting harness."""

    def test_all_pass(self):
        # Fully correct code → passed == total (exact counts are
        # harness-dependent, so only the relationship is asserted).
        code = "def add(a, b):\n    return a + b\n"
        test_code = """def check(candidate):
    assert candidate(1, 2) == 3
    assert candidate(0, 0) == 0
"""
        passed, total = count_test_results(code, test_code, "add")
        assert passed > 0
        assert passed == total

    def test_none_pass(self):
        # An implementation failing every assertion counts zero passes.
        code = "def add(a, b):\n    return 0\n"
        test_code = """def check(candidate):
    assert candidate(1, 2) == 3
    assert candidate(5, 5) == 10
"""
        passed, total = count_test_results(code, test_code, "add")
        assert passed == 0

    def test_syntax_error_returns_zero(self):
        # Uncompilable code counts as zero passes.
        code = "def add(a, b)\n    return a + b\n"
        test_code = """def check(candidate):
    assert candidate(1, 2) == 3
"""
        passed, total = count_test_results(code, test_code, "add")
        assert passed == 0
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Tests for scoring logic
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class TestScoringLogic:
    """Test the scoring logic that will be used in CodeDebugEnv._score_fix."""

    def test_perfect_fix_scores_one(self):
        """A fix that passes all tests should score 1.0."""
        # Exercises the execution path _score_fix uses for its 1.0 branch.
        code = "def add(a, b):\n    return a + b\n"
        test_code = """def check(candidate):
    assert candidate(1, 2) == 3
    assert candidate(-1, 1) == 0
"""
        passed, error = execute_code_with_tests(code, test_code, "add")
        assert passed is True
        # Score would be 1.0

    def test_no_fix_scores_negative(self):
        """A fix that doesn't improve should score negatively."""
        # A still-buggy candidate fails execution, the -1.0/-0.5 branches.
        buggy = "def add(a, b):\n    return a - b\n"
        test_code = """def check(candidate):
    assert candidate(1, 2) == 3
"""
        passed, error = execute_code_with_tests(buggy, test_code, "add")
        assert passed is False
        # Score would be -1.0
|
||||
|
||||
|
||||
# Allow running this file directly without invoking pytest from the shell.
if __name__ == "__main__":
    pytest.main([__file__, "-v"])
|
||||
Loading…
Add table
Add a link
Reference in a new issue