feat: add code_debug community environment

This commit is contained in:
RUFFY-369 2026-03-24 13:05:15 +05:30
parent c421582b6f
commit 590e8a1ef2
4 changed files with 930 additions and 0 deletions

View file

@ -0,0 +1,81 @@
# Code Debug Environment
An Atropos RL environment for training LLMs to debug and fix buggy Python code.
## Overview
This environment uses the [HumanEvalFix](https://huggingface.co/datasets/bigcode/humanevalfix-python) dataset, which contains 164 buggy Python functions with associated test suites. The model receives a buggy function and must output the corrected version inside `\boxed{}`. Scoring is done by executing the fixed code against the original test cases.
## Architecture
```
code_debug_env.py # Main env (extends BaseEnv)
code_executor.py # Safe subprocess execution with timeout
test_code_debug.py # Unit tests
README.md # This file
```
## Reward Design
| Outcome | Score | Description |
|---------|-------|-------------|
| All tests pass | **1.0** | Perfect fix |
| Partial improvement | **-0.5 to 0.9** | More tests pass than buggy version |
| No improvement | **-0.5** | Code runs but doesn't fix anything |
| Compilation error / regression | **-1.0** | Fix is worse than the original |
When all rollouts in a group score 1.0, a **length penalty** is applied to encourage concise solutions (same pattern as `sql_query_env`).
## Setup
```bash
# Install dependencies (datasets is the only extra)
pip install datasets
# Run tests
cd environments/community/code_debug_env
python -m pytest test_code_debug.py -v
```
## Usage
```bash
# Process mode (offline data generation)
python code_debug_env.py process \
--env.data_path_to_save_groups data/code_debug.jsonl \
--env.group_size 8 \
--openai.base_url http://localhost:8000/v1 \
--openai.model_name "NousResearch/DeepHermes-3-Llama-3-3B-Preview"
# Serve mode (online RL training)
python code_debug_env.py serve \
--openai.base_url http://localhost:9001/v1 \
--openai.model_name "NousResearch/DeepHermes-3-Llama-3-3B-Preview"
# Evaluate mode
python code_debug_env.py evaluate \
--openai.base_url http://localhost:8000/v1 \
--openai.model_name "NousResearch/DeepHermes-3-Llama-3-3B-Preview"
```
## WandB Metrics
| Metric | Description |
|--------|-------------|
| `train/percent_correct` | Fraction of rollouts that pass all tests |
| `train/avg_score` | Average reward across rollouts (negative scores are clamped to 0 before averaging) |
| `train/partial_fix_rate` | Fraction of rollouts that partially fix the code |
| `eval/percent_correct` | Eval set accuracy |
## Dataset
- **Source**: [bigcode/humanevalfix-python](https://huggingface.co/datasets/bigcode/humanevalfix-python)
- **License**: Apache 2.0
- **Size**: 164 problems
- **Split**: 80% train / 20% test
## Compute Footprint
- **RAM**: < 1 GB (dataset is small, execution is in subprocess)
- **CPU**: < 5s per verification (subprocess with 10s timeout)
- **GPU**: Only needed for the inference server

View file

@ -0,0 +1,457 @@
"""
Code Debug Environment for Atropos
Trains LLMs to debug and fix buggy Python functions.
Uses the HumanEvalFix dataset with execution-based verification
against ground-truth test cases.
Environment pattern follows sql_query_env for consistency.
"""
import random
from typing import Any, Dict, List, Optional, Tuple, TypedDict, Union
from code_executor import (
count_test_results,
execute_code_with_tests,
extract_boxed_code,
)
from tqdm.asyncio import tqdm_asyncio
from atroposlib.envs.base import (
APIServerConfig,
BaseEnv,
BaseEnvConfig,
ScoredDataGroup,
)
from atroposlib.type_definitions import Item
# System prompt following established Atropos patterns
SYSTEM_PROMPT = (
"You are a deep thinking AI, you may use extremely long chains of thought "
"to deeply consider the problem and deliberate with yourself via systematic "
"reasoning processes to help come to a correct solution prior to answering. "
"You should enclose your thoughts and internal monologue inside <think> </think> "
"tags, and then provide your solution or response to the problem.\n\n"
)
SYSTEM_PROMPT += """You are an expert Python debugger. Given a buggy Python function \
and its test cases, identify the bug and provide the corrected function.
You are allocated a maximum of 2048 tokens, please strive to use less.
Provide your corrected function inside \\boxed{} like this:
\\boxed{def function_name(args):
# your corrected code here
return result
}
Important:
- Keep the same function signature
- Only fix the bug, don't rewrite the function from scratch unless necessary
- Ensure the function passes all provided test cases
- Do not include test cases or the check function in your answer
End your answer with \\boxed{your corrected function here}"""
class CodeDebugItem(TypedDict):
    """Type definition for a HumanEvalFix dataset item."""

    task_id: str  # unique problem identifier from the dataset
    prompt: str  # filled from the dataset's "declaration" column in setup()
    buggy_solution: str  # function body containing the seeded bug
    canonical_solution: str  # reference correct function body
    test: str  # test module source defining check(candidate)
    entry_point: str  # name of the function under test
def format_debug_prompt(item: CodeDebugItem) -> str:
    """Build the user-turn prompt presenting the buggy function to the model.

    The prompt contains the full buggy source (declaration + body) in a
    fenced code block, names the failing entry point, and asks for the
    corrected function inside \\boxed{}.
    """
    buggy_source = item["prompt"] + item["buggy_solution"]
    entry = item["entry_point"]
    parts = [
        "Here is a buggy Python function:\n\n",
        f"```python\n{buggy_source}```\n\n",
        f"The function `{entry}` has a bug. ",
        "It fails its test cases.\n\n",
        "Please identify the bug, fix it, and provide the corrected ",
        "function inside \\boxed{{}}.".format(),
    ]
    return "".join(parts)
class CodeDebugEnv(BaseEnv):
    """
    Environment for training LLMs to debug Python code.

    Uses the HumanEvalFix dataset. The model receives a buggy function
    and must output the corrected version. Scoring is done by executing
    the fixed code against the original test suite.
    """

    # Environment name used for registration / CLI selection.
    name = "code_debug"

    def __init__(
        self,
        config: BaseEnvConfig,
        server_configs: List[APIServerConfig],
        slurm=True,
        testing=False,
    ):
        super().__init__(config, server_configs, slurm, testing)
        # Rolling metric buffers; flushed and reset on every wandb_log() call.
        self.percent_correct_buffer = list()  # 1.0/0.0 per rollout: all tests passed
        self.eval_metrics = list()  # (metric_name, value) pairs from evaluate()
        self.partial_fix_buffer = list()  # 1/0 per rollout: partial (not full) fix
        self.raw_score_buffer = list()  # per-rollout reward, clamped to >= 0.0 in score()

    @classmethod
    def config_init(cls) -> Tuple[BaseEnvConfig, List[APIServerConfig]]:
        """Initialize default configuration for the environment."""
        env_config = BaseEnvConfig(
            tokenizer_name="NousResearch/DeepHermes-3-Llama-3-3B-Preview",
            group_size=8,
            use_wandb=True,
            rollout_server_url="http://localhost:8000",
            total_steps=1000,
            batch_size=12,
            steps_per_eval=100,
            # Matches the 2048-token budget stated in SYSTEM_PROMPT.
            max_token_length=2048,
            wandb_name="code_debug",
        )
        server_configs = [
            APIServerConfig(
                model_name="NousResearch/DeepHermes-3-Llama-3-3B-Preview",
                base_url="http://localhost:9001/v1",
                api_key="x",
                num_requests_for_eval=256,
            ),
        ]
        return env_config, server_configs

    async def wandb_log(self, wandb_metrics: Optional[Dict] = None):
        """Log custom metrics to WandB, then reset the rolling buffers."""
        if wandb_metrics is None:
            wandb_metrics = {}
        # Log percent fully correct (all tests pass)
        if self.percent_correct_buffer:
            wandb_metrics["train/percent_correct"] = sum(
                self.percent_correct_buffer
            ) / len(self.percent_correct_buffer)
        # Log average raw score (includes partial credit; negatives were
        # clamped to 0.0 when buffered in score())
        if self.raw_score_buffer:
            wandb_metrics["train/avg_score"] = sum(self.raw_score_buffer) / len(
                self.raw_score_buffer
            )
        # Log partial fix rate (code runs but doesn't pass all tests)
        if self.partial_fix_buffer:
            wandb_metrics["train/partial_fix_rate"] = sum(
                self.partial_fix_buffer
            ) / len(self.partial_fix_buffer)
        # Reset so every log reflects only the window since the last flush.
        self.percent_correct_buffer = list()
        self.raw_score_buffer = list()
        self.partial_fix_buffer = list()
        for item in self.eval_metrics:
            wandb_metrics[item[0]] = item[1]
        self.eval_metrics = list()
        await super().wandb_log(wandb_metrics)

    async def setup(self):
        """Load the HumanEvalFix dataset and prepare train/test splits."""
        from datasets import load_dataset

        print("Loading HumanEvalFix dataset...")
        dataset = load_dataset("bigcode/humanevalfix-python", split="test")
        all_items: List[CodeDebugItem] = []
        for row in dataset:
            all_items.append(
                {
                    "task_id": row["task_id"],
                    # The dataset's "declaration" column becomes our "prompt".
                    "prompt": row["declaration"],
                    "buggy_solution": row["buggy_solution"],
                    "canonical_solution": row["canonical_solution"],
                    "test": row["test"],
                    "entry_point": row["entry_point"],
                }
            )
        print(f"Loaded {len(all_items)} problems")
        # Verify a few items actually work with canonical solutions — a
        # sanity check that the execution harness itself is functional.
        verified = 0
        for item in all_items[:10]:
            code = item["prompt"] + item["canonical_solution"]
            passed, _ = execute_code_with_tests(
                code, item["test"], item["entry_point"]
            )
            if passed:
                verified += 1
        print(f"Verified {verified}/10 canonical solutions execute correctly")
        # Split 80/20 train/test
        # NOTE(review): the shuffle is unseeded, so the split differs per
        # run/process — confirm this is intended before comparing eval runs.
        random.shuffle(all_items)
        split_idx = int(len(all_items) * 0.8)
        self.train = all_items[:split_idx]
        self.test = all_items[split_idx:]
        print(f"Train: {len(self.train)}, Test: {len(self.test)}")
        self.iter = 0

    def save_checkpoint(self, step, data=None):
        """Save checkpoint with iteration state so training can resume in place."""
        if data is None:
            data = {}
        data["iter"] = self.iter
        super().save_checkpoint(step, data)

    def _score_fix(self, generated_code: str, item: CodeDebugItem) -> Tuple[float, bool]:
        """
        Score a generated fix by execution against test cases.

        Returns:
            Tuple of (score, is_partial_fix):
                - score: 1.0 if all tests pass; scaled between -0.5 and
                  0.9 for a partial improvement over the buggy baseline;
                  -0.5 if the code runs but shows no improvement;
                  -1.0 on regression, compilation error, or missing code
                - is_partial_fix: True if some but not all tests pass
        """
        # No \boxed{} code could be extracted from the response.
        if not generated_code:
            return -1.0, False
        # Run the fixed code against tests
        all_passed, error = execute_code_with_tests(
            generated_code, item["test"], item["entry_point"]
        )
        if all_passed:
            return 1.0, False
        # Check for partial credit — how many tests pass?
        passed, total = count_test_results(
            generated_code, item["test"], item["entry_point"]
        )
        if total == 0:
            return -1.0, False
        # Also check how the buggy code does (baseline for "improvement").
        buggy_code = item["prompt"] + item["buggy_solution"]
        buggy_passed, buggy_total = count_test_results(
            buggy_code, item["test"], item["entry_point"]
        )
        # Score based on improvement over buggy code
        if passed > buggy_passed:
            # Partial improvement: ratio in (0, 1] maps linearly onto
            # (-0.5, 0.9], keeping a full fix (1.0) strictly better.
            improvement_ratio = (passed - buggy_passed) / max(1, total - buggy_passed)
            score = -0.5 + 1.4 * improvement_ratio
            return score, True
        elif passed == buggy_passed and passed > 0:
            # No improvement but code at least runs
            return -0.5, True
        else:
            # Made things worse or code doesn't compile
            return -1.0, False

    async def rollout_and_score_eval(self, item: CodeDebugItem) -> dict:
        """Rollout and score a single evaluation item.

        Eval scoring is binary (1 only when every test passes), unlike
        training, which awards partial credit via _score_fix().
        """
        user_content = format_debug_prompt(item)
        async with self.server.managed_server(tokenizer=self.tokenizer) as managed:
            completion = await managed.chat_completion(
                messages=[
                    {"role": "system", "content": SYSTEM_PROMPT},
                    {"role": "user", "content": user_content},
                ],
                n=1,
                max_tokens=self.config.max_token_length,
                temperature=0.6,
            )
        response_content = completion.choices[0].message.content
        # Extract and score generated fix
        generated_code = extract_boxed_code(response_content)
        score, is_partial = self._score_fix(generated_code, item)
        correct = score == 1.0
        sample = {
            "messages": [
                {"role": "system", "content": SYSTEM_PROMPT},
                {"role": "user", "content": user_content},
                {"role": "assistant", "content": response_content},
            ],
            "question": f"Fix bug in {item['entry_point']}",
            "gold_answer": item["canonical_solution"],
            "model_parsed": generated_code or "(no code extracted)",
            "score": 1 if correct else 0,
            "correct": correct,
            "finish_reason": completion.choices[0].finish_reason,
        }
        return {"score": 1 if correct else 0, "sample": sample}

    async def evaluate(self, *args, **kwargs):
        """Run evaluation on (up to 100 items of) the test split."""
        import time

        start_time = time.time()
        eval_tasks = []
        # Cap at 100 items to bound evaluation time.
        for item in self.test[:100]:
            eval_tasks.append(self.rollout_and_score_eval(item))
        results = await tqdm_asyncio.gather(*eval_tasks)
        scores = [result["score"] for result in results]
        samples = [result["sample"] for result in results]
        percent_correct = sum(scores) / len(scores) if scores else 0
        end_time = time.time()
        self.eval_metrics.append(("eval/percent_correct", percent_correct))
        eval_metrics = {
            "eval/percent_correct": percent_correct,
        }
        await self.evaluate_log(
            metrics=eval_metrics,
            samples=samples,
            start_time=start_time,
            end_time=end_time,
            generation_parameters={
                "temperature": 0.6,
                "max_tokens": self.config.max_token_length,
            },
        )

    async def collect_trajectories(
        self, item: CodeDebugItem
    ) -> Tuple[ScoredDataGroup, list[Item]]:
        """Generate group_size candidate fixes for one buggy function and score them."""
        user_content = format_debug_prompt(item)
        user_message = {"role": "user", "content": user_content}
        async with self.server.managed_server(tokenizer=self.tokenizer) as managed:
            chat_completions = await managed.chat_completion(
                messages=[
                    {"role": "system", "content": SYSTEM_PROMPT},
                    user_message,
                ],
                n=self.config.group_size,
                max_tokens=self.config.max_token_length,
                # Higher temperature than eval to diversify the group.
                temperature=1.0,
            )
            # Per-choice token/mask/logprob bookkeeping captured by the
            # managed server; assumed to be index-aligned with
            # chat_completions.choices — TODO confirm against BaseEnv docs.
            state = managed.get_state()
            nodes = state["nodes"]
        to_score = list()
        to_backlog = list()
        for i, chat_completion in enumerate(chat_completions.choices):
            messages = (
                {"role": "system", "content": SYSTEM_PROMPT},
                user_message,
                {"role": "assistant", "content": chat_completion.message.content},
            )
            to_score.append(
                {
                    "messages": messages,
                    "item": item,
                    "finish_reason": chat_completion.finish_reason,
                    "tokens": nodes[i].tokens,
                    "masks": nodes[i].masked_tokens,
                    "logprobs": nodes[i].logprobs,
                }
            )
        to_postprocess = await self.score(to_score)
        return to_postprocess, to_backlog

    async def score(
        self, rollout_group_data
    ) -> Union[Optional[ScoredDataGroup], List[Optional[ScoredDataGroup]]]:
        """Score generated code fixes by execution against test cases.

        Returns None (discarding the group) when every rollout ends up with
        the same score, since GRPO-style training needs in-group variance.
        """
        scores = ScoredDataGroup()
        scores["tokens"] = list()
        scores["masks"] = list()
        scores["scores"] = list()
        scores["inference_logprobs"] = list()
        # All rollouts in a group share the same item.
        item = rollout_group_data[0]["item"]
        random.shuffle(rollout_group_data)
        for rollout in rollout_group_data:
            response_content = rollout["messages"][-1]["content"]
            # Extract fixed code from \boxed{}
            generated_code = extract_boxed_code(response_content)
            # Score by execution
            reward, is_partial = self._score_fix(generated_code, item)
            self.partial_fix_buffer.append(1 if is_partial else 0)
            tokens = rollout["tokens"]
            masks = rollout["masks"]
            logprobs = rollout["logprobs"]
            # Remove obviously bad examples (too short responses)
            if len([1 for m in masks if m != -100]) < 10:
                continue
            scores["tokens"].append(tokens)
            scores["masks"].append(masks)
            scores["inference_logprobs"].append(logprobs)
            scores["scores"].append(reward)
            if len(scores["tokens"]) >= self.config.group_size:
                break
        for score in scores["scores"]:
            self.percent_correct_buffer.append(1.0 if score >= 1.0 else 0.0)
            # Clamp negatives so train/avg_score reflects non-negative reward.
            self.raw_score_buffer.append(max(score, 0.0))
        # Apply length penalty when all scores are perfect, rewarding
        # concise solutions (same pattern as sql_query_env).
        if scores["scores"] and all(score == 1.0 for score in scores["scores"]):
            token_lengths = [len(t) for t in scores["tokens"]]
            if max(token_lengths) == 0:
                return None
            max_allowed_length = self.config.max_token_length
            # Penalize only responses beyond half the budget.
            length_threshold = max_allowed_length * 0.5
            scores["scores"] = []
            for length in token_lengths:
                if length <= length_threshold:
                    scores["scores"].append(1.0)
                else:
                    pct = (length - length_threshold) / (
                        max_allowed_length - length_threshold
                    )
                    pct = min(pct, 1.0)
                    scores["scores"].append(1.0 - pct)
        # If all scores are same, return None (GRPO needs variance)
        if scores["scores"] and all(
            scores["scores"][0] == s for s in scores["scores"]
        ):
            return None
        return scores

    async def get_next_item(self) -> CodeDebugItem:
        """Get the next training item (round-robin over the train split)."""
        next_item = self.train[self.iter % len(self.train)]
        self.iter += 1
        return next_item
if __name__ == "__main__":
    # Atropos CLI entry point (serve / process / evaluate subcommands).
    CodeDebugEnv.cli()

View file

@ -0,0 +1,213 @@
"""
Safe code execution utilities for the Code Debug environment.
Runs generated code in isolated subprocess with timeout and resource limits.
"""
import os
import subprocess
import sys
import tempfile
from typing import Optional, Tuple
def execute_code_with_tests(
    code: str,
    test_code: str,
    entry_point: str,
    timeout: int = 10,
) -> Tuple[bool, str]:
    """
    Execute code with test cases in an isolated subprocess.

    The candidate implementation, the test module and a trailing
    `check(<entry_point>)` call are concatenated into a temporary script
    and run with the current interpreter; exit status 0 means every
    assertion passed.

    Args:
        code: The function implementation to test.
        test_code: Test code containing a `check(candidate)` function.
        entry_point: The function name to pass to `check()`.
        timeout: Maximum execution time in seconds.

    Returns:
        Tuple of (all_tests_passed, error_message). The message is empty
        on success; otherwise it holds the (tail-truncated) stderr, a
        timeout notice, or the setup exception.
    """
    full_code = code + "\n\n" + test_code + f"\n\ncheck({entry_point})\n"
    tmp_path = None
    try:
        with tempfile.NamedTemporaryFile(
            mode="w", suffix=".py", delete=False, prefix="code_debug_"
        ) as f:
            f.write(full_code)
            tmp_path = f.name
        # Use the running interpreter instead of a hard-coded "python3":
        # matches the env's own Python/venv and works on platforms where
        # "python3" is not on PATH (e.g. Windows).
        result = subprocess.run(
            [sys.executable, tmp_path],
            capture_output=True,
            text=True,
            timeout=timeout,
        )
        if result.returncode == 0:
            return True, ""
        error = result.stderr.strip()
        # Truncate long tracebacks; keep the tail, where the error lives.
        if len(error) > 500:
            error = error[-500:]
        return False, error
    except subprocess.TimeoutExpired:
        return False, "Execution timed out"
    except Exception as e:
        # Defensive: temp-file or spawn failures must not crash the env.
        return False, f"{type(e).__name__}: {str(e)[:200]}"
    finally:
        if tmp_path and os.path.exists(tmp_path):
            os.unlink(tmp_path)
def extract_boxed_code(text: str) -> Optional[str]:
    """
    Extract code from \\boxed{...} format, handling nested braces.

    Scans backwards for the FINAL \\boxed{ marker — earlier, discarded
    attempts in the response are ignored — then walks forward matching
    braces until the opening brace is balanced.

    Args:
        text: The model's response text.

    Returns:
        The extracted code string, or None when no \\boxed{ exists or
        the final one is never closed.
    """
    marker = "\\boxed{"
    marker_at = text.rfind(marker)
    if marker_at == -1:
        return None
    body_start = marker_at + len(marker)
    depth = 1
    for pos in range(body_start, len(text)):
        ch = text[pos]
        if ch == "{":
            depth += 1
        elif ch == "}":
            depth -= 1
            if depth == 0:
                # pos is the matching close brace; exclude it.
                return text[body_start:pos].strip()
    # Ran off the end with an unbalanced brace.
    return None
def count_test_results(
    code: str,
    test_code: str,
    entry_point: str,
    timeout: int = 10,
) -> Tuple[int, int]:
    """
    Count how many individual assertions pass vs fail.

    Assertions are counted by scanning the test source text directly, one
    per line starting with ``assert``. This happens here in the parent
    process, where the text is available — ``inspect.getsource`` cannot
    recover the source of an exec'd function (it raises OSError), which
    made the previous in-subprocess counting path dead code. The harness
    first runs the full ``check(candidate)``; on AssertionError it
    re-executes each assert line individually for partial credit. The
    test text and assert lines are embedded via ``repr`` so quotes or
    backslashes in them cannot break the generated script.

    Args:
        code: The function implementation to test.
        test_code: Test code containing a `check(candidate)` function.
        entry_point: The function name to pass.
        timeout: Maximum execution time in seconds.

    Returns:
        Tuple of (passed_count, total_count). (0, 0) when `check` is
        missing; (0, 1) on timeout, syntax error, or harness failure.
    """
    # Count assertion lines up front from the raw test text.
    assert_lines = [
        ln.strip() for ln in test_code.split("\n") if ln.strip().startswith("assert")
    ]
    total = max(len(assert_lines), 1)
    counter_code = f"""
{code}

_ns = {{'__builtins__': __builtins__}}
exec({test_code!r}, _ns)
_check = _ns.get('check')
_assert_lines = {assert_lines!r}
_total = {total}
_passed = 0
if _check is None:
    print("0/0")
else:
    try:
        _check({entry_point})
        _passed = _total
    except AssertionError:
        # Some failed — re-run each assertion alone to count successes.
        for _line in _assert_lines:
            try:
                exec(_line, {{'__builtins__': __builtins__,
                              '{entry_point}': {entry_point},
                              'candidate': {entry_point}}})
                _passed += 1
            except Exception:
                pass
    except Exception:
        _passed = 0
    print(f"{{_passed}}/{{_total}}")
"""
    tmp_path = None
    try:
        with tempfile.NamedTemporaryFile(
            mode="w", suffix=".py", delete=False, prefix="code_debug_count_"
        ) as f:
            f.write(counter_code)
            tmp_path = f.name
        # Current interpreter, consistent with execute_code_with_tests.
        result = subprocess.run(
            [sys.executable, tmp_path],
            capture_output=True,
            text=True,
            timeout=timeout,
        )
        stdout = result.stdout.strip()
        if "/" in stdout:
            parts = stdout.split("/")
            try:
                return int(parts[0]), int(parts[1])
            except (ValueError, IndexError):
                pass
        # Fallback: if execution succeeded, assume all passed.
        if result.returncode == 0:
            return 1, 1
        return 0, 1
    except subprocess.TimeoutExpired:
        return 0, 1
    except Exception:
        return 0, 1
    finally:
        if tmp_path and os.path.exists(tmp_path):
            os.unlink(tmp_path)

View file

@ -0,0 +1,179 @@
"""
Unit tests for the Code Debug environment.
Tests the code execution, scoring, and extraction utilities
without requiring a running inference server.
"""
import pytest
from code_executor import count_test_results, execute_code_with_tests, extract_boxed_code
# ============================================================================
# Tests for extract_boxed_code
# ============================================================================
class TestExtractBoxedCode:
    """Behavior of the \\boxed{} extraction helper."""

    def test_simple_extraction(self):
        reply = r"""Here is the fix:
\boxed{def add(a, b):
    return a + b
}"""
        code = extract_boxed_code(reply)
        assert code is not None
        assert "def add(a, b):" in code
        assert "return a + b" in code

    def test_nested_braces(self):
        reply = r"""\boxed{def foo(x):
    if x > 0:
        return {x: x**2}
    return {}
}"""
        code = extract_boxed_code(reply)
        assert code is not None
        assert "return {x: x**2}" in code
        assert "return {}" in code

    def test_no_boxed(self):
        assert extract_boxed_code("Just some text without any boxed content") is None

    def test_last_boxed_used(self):
        reply = r"""First attempt: \boxed{def bad(): pass}
Actually, let me reconsider: \boxed{def good():
    return 42
}"""
        code = extract_boxed_code(reply)
        assert code is not None
        assert "def good():" in code
        assert "return 42" in code

    def test_empty_boxed(self):
        assert extract_boxed_code(r"\boxed{}") == ""
# ============================================================================
# Tests for execute_code_with_tests
# ============================================================================
class TestExecuteCodeWithTests:
    """Subprocess execution: pass, fail, syntax error, timeout, crash."""

    def test_correct_code_passes(self):
        impl = "def add(a, b):\n return a + b\n"
        checks = """def check(candidate):
    assert candidate(1, 2) == 3
    assert candidate(-1, 1) == 0
    assert candidate(0, 0) == 0
"""
        ok, err = execute_code_with_tests(impl, checks, "add")
        assert ok is True
        assert err == ""

    def test_buggy_code_fails(self):
        impl = "def add(a, b):\n return a - b\n"
        checks = """def check(candidate):
    assert candidate(1, 2) == 3
    assert candidate(-1, 1) == 0
"""
        ok, err = execute_code_with_tests(impl, checks, "add")
        assert ok is False
        assert err != ""

    def test_syntax_error(self):
        impl = "def add(a, b)\n return a + b\n"  # Missing colon
        checks = """def check(candidate):
    assert candidate(1, 2) == 3
"""
        ok, _err = execute_code_with_tests(impl, checks, "add")
        assert ok is False

    def test_infinite_loop_timeout(self):
        impl = "def loop(x):\n while True: pass\n"
        checks = """def check(candidate):
    assert candidate(1) is None
"""
        ok, err = execute_code_with_tests(impl, checks, "loop", timeout=2)
        assert ok is False
        assert "timed out" in err.lower() or "Timeout" in err

    def test_runtime_error(self):
        impl = "def divide(a, b):\n return a / b\n"
        checks = """def check(candidate):
    assert candidate(1, 0) == 0
"""
        ok, _err = execute_code_with_tests(impl, checks, "divide")
        assert ok is False
# ============================================================================
# Tests for count_test_results
# ============================================================================
class TestCountTestResults:
    """Partial-credit counting: all pass, none pass, syntax error."""

    def test_all_pass(self):
        impl = "def add(a, b):\n return a + b\n"
        checks = """def check(candidate):
    assert candidate(1, 2) == 3
    assert candidate(0, 0) == 0
"""
        passed, total = count_test_results(impl, checks, "add")
        assert passed > 0
        assert passed == total

    def test_none_pass(self):
        impl = "def add(a, b):\n return 0\n"
        checks = """def check(candidate):
    assert candidate(1, 2) == 3
    assert candidate(5, 5) == 10
"""
        passed, _total = count_test_results(impl, checks, "add")
        assert passed == 0

    def test_syntax_error_returns_zero(self):
        impl = "def add(a, b)\n return a + b\n"
        checks = """def check(candidate):
    assert candidate(1, 2) == 3
"""
        passed, _total = count_test_results(impl, checks, "add")
        assert passed == 0
# ============================================================================
# Tests for scoring logic
# ============================================================================
class TestScoringLogic:
    """Test the scoring logic that will be used in CodeDebugEnv._score_fix."""

    def test_perfect_fix_scores_one(self):
        """A fix that passes all tests should score 1.0."""
        impl = "def add(a, b):\n return a + b\n"
        checks = """def check(candidate):
    assert candidate(1, 2) == 3
    assert candidate(-1, 1) == 0
"""
        ok, _err = execute_code_with_tests(impl, checks, "add")
        assert ok is True
        # Score would be 1.0

    def test_no_fix_scores_negative(self):
        """A fix that doesn't improve should score negatively."""
        still_buggy = "def add(a, b):\n return a - b\n"
        checks = """def check(candidate):
    assert candidate(1, 2) == 3
"""
        ok, _err = execute_code_with_tests(still_buggy, checks, "add")
        assert ok is False
        # Score would be -1.0
if __name__ == "__main__":
    # Allow running this file directly, without a pytest invocation.
    pytest.main([__file__, "-v"])