diff --git a/environments/community/code_debug_env/README.md b/environments/community/code_debug_env/README.md new file mode 100644 index 00000000..eed1e2d9 --- /dev/null +++ b/environments/community/code_debug_env/README.md @@ -0,0 +1,81 @@ +# Code Debug Environment + +An Atropos RL environment for training LLMs to debug and fix buggy Python code. + +## Overview + +This environment uses the [HumanEvalFix](https://huggingface.co/datasets/bigcode/humanevalfix-python) dataset, which contains 164 buggy Python functions with associated test suites. The model receives a buggy function and must output the corrected version inside `\boxed{}`. Scoring is done by executing the fixed code against the original test cases. + +## Architecture + +``` +code_debug_env.py # Main env (extends BaseEnv) +code_executor.py # Safe subprocess execution with timeout +test_code_debug.py # Unit tests +README.md # This file +``` + +## Reward Design + +| Outcome | Score | Description | +|---------|-------|-------------| +| All tests pass | **1.0** | Perfect fix | +| Partial improvement | **-0.5 to 0.9** | More tests pass than buggy version | +| No improvement | **-0.5** | Code runs but doesn't fix anything | +| Compilation error / regression | **-1.0** | Fix is worse than the original | + +When all rollouts in a group score 1.0, a **length penalty** is applied to encourage concise solutions (same pattern as `sql_query_env`). 
+ +## Setup + +```bash +# Install dependencies (datasets is the only extra) +pip install datasets + +# Run tests +cd environments/community/code_debug_env +python -m pytest test_code_debug.py -v +``` + +## Usage + +```bash +# Process mode (offline data generation) +python code_debug_env.py process \ + --env.data_path_to_save_groups data/code_debug.jsonl \ + --env.group_size 8 \ + --openai.base_url http://localhost:8000/v1 \ + --openai.model_name "NousResearch/DeepHermes-3-Llama-3-3B-Preview" + +# Serve mode (online RL training) +python code_debug_env.py serve \ + --openai.base_url http://localhost:9001/v1 \ + --openai.model_name "NousResearch/DeepHermes-3-Llama-3-3B-Preview" + +# Evaluate mode +python code_debug_env.py evaluate \ + --openai.base_url http://localhost:8000/v1 \ + --openai.model_name "NousResearch/DeepHermes-3-Llama-3-3B-Preview" +``` + +## WandB Metrics + +| Metric | Description | +|--------|-------------| +| `train/percent_correct` | Fraction of rollouts that pass all tests | +| `train/avg_score` | Average clamped reward across rollouts (negative scores are floored at 0 before logging) | +| `train/partial_fix_rate` | Fraction of rollouts that partially fix the code | +| `eval/percent_correct` | Eval set accuracy | + +## Dataset + +- **Source**: [bigcode/humanevalfix-python](https://huggingface.co/datasets/bigcode/humanevalfix-python) +- **License**: Apache 2.0 +- **Size**: 164 problems +- **Split**: 80% train / 20% test + +## Compute Footprint + +- **RAM**: < 1 GB (dataset is small, execution is in subprocess) +- **CPU**: < 5s per verification (subprocess with 10s timeout) +- **GPU**: Only needed for the inference server diff --git a/environments/community/code_debug_env/code_debug_env.py b/environments/community/code_debug_env/code_debug_env.py new file mode 100644 index 00000000..fc9ab8d5 --- /dev/null +++ b/environments/community/code_debug_env/code_debug_env.py @@ -0,0 +1,457 @@ +""" +Code Debug Environment for Atropos + +Trains LLMs to debug and fix buggy Python functions. 
+Uses the HumanEvalFix dataset with execution-based verification +against ground-truth test cases. + +Environment pattern follows sql_query_env for consistency. +""" + +import random +from typing import Any, Dict, List, Optional, Tuple, TypedDict, Union + +from code_executor import ( + count_test_results, + execute_code_with_tests, + extract_boxed_code, +) +from tqdm.asyncio import tqdm_asyncio + +from atroposlib.envs.base import ( + APIServerConfig, + BaseEnv, + BaseEnvConfig, + ScoredDataGroup, +) +from atroposlib.type_definitions import Item + +# System prompt following established Atropos patterns +SYSTEM_PROMPT = ( + "You are a deep thinking AI, you may use extremely long chains of thought " + "to deeply consider the problem and deliberate with yourself via systematic " + "reasoning processes to help come to a correct solution prior to answering. " + "You should enclose your thoughts and internal monologue inside " + "tags, and then provide your solution or response to the problem.\n\n" +) + +SYSTEM_PROMPT += """You are an expert Python debugger. Given a buggy Python function \ +and its test cases, identify the bug and provide the corrected function. + +You are allocated a maximum of 2048 tokens, please strive to use less. 
+ +Provide your corrected function inside \\boxed{} like this: +\\boxed{def function_name(args): + # your corrected code here + return result +} + +Important: +- Keep the same function signature +- Only fix the bug, don't rewrite the function from scratch unless necessary +- Ensure the function passes all provided test cases +- Do not include test cases or the check function in your answer + +End your answer with \\boxed{your corrected function here}""" + + +class CodeDebugItem(TypedDict): + """Type definition for a HumanEvalFix dataset item.""" + + task_id: str + prompt: str + buggy_solution: str + canonical_solution: str + test: str + entry_point: str + + +def format_debug_prompt(item: CodeDebugItem) -> str: + """Format the buggy code and context into a prompt for the model.""" + buggy_code = item["prompt"] + item["buggy_solution"] + # Show test structure without revealing the exact assertions + return ( + f"Here is a buggy Python function:\n\n" + f"```python\n{buggy_code}```\n\n" + f"The function `{item['entry_point']}` has a bug. " + f"It fails its test cases.\n\n" + f"Please identify the bug, fix it, and provide the corrected " + f"function inside \\boxed{{}}." + ) + + +class CodeDebugEnv(BaseEnv): + """ + Environment for training LLMs to debug Python code. + + Uses the HumanEvalFix dataset. The model receives a buggy function + and must output the corrected version. Scoring is done by executing + the fixed code against the original test suite. 
+ """ + + name = "code_debug" + + def __init__( + self, + config: BaseEnvConfig, + server_configs: List[APIServerConfig], + slurm=True, + testing=False, + ): + super().__init__(config, server_configs, slurm, testing) + self.percent_correct_buffer = list() + self.eval_metrics = list() + self.partial_fix_buffer = list() + self.raw_score_buffer = list() + + @classmethod + def config_init(cls) -> Tuple[BaseEnvConfig, List[APIServerConfig]]: + """Initialize default configuration for the environment.""" + env_config = BaseEnvConfig( + tokenizer_name="NousResearch/DeepHermes-3-Llama-3-3B-Preview", + group_size=8, + use_wandb=True, + rollout_server_url="http://localhost:8000", + total_steps=1000, + batch_size=12, + steps_per_eval=100, + max_token_length=2048, + wandb_name="code_debug", + ) + server_configs = [ + APIServerConfig( + model_name="NousResearch/DeepHermes-3-Llama-3-3B-Preview", + base_url="http://localhost:9001/v1", + api_key="x", + num_requests_for_eval=256, + ), + ] + return env_config, server_configs + + async def wandb_log(self, wandb_metrics: Optional[Dict] = None): + """Log custom metrics to WandB.""" + if wandb_metrics is None: + wandb_metrics = {} + + # Log percent fully correct (all tests pass) + if self.percent_correct_buffer: + wandb_metrics["train/percent_correct"] = sum( + self.percent_correct_buffer + ) / len(self.percent_correct_buffer) + + # Log average raw score (includes partial credit) + if self.raw_score_buffer: + wandb_metrics["train/avg_score"] = sum(self.raw_score_buffer) / len( + self.raw_score_buffer + ) + + # Log partial fix rate (code runs but doesn't pass all tests) + if self.partial_fix_buffer: + wandb_metrics["train/partial_fix_rate"] = sum( + self.partial_fix_buffer + ) / len(self.partial_fix_buffer) + + self.percent_correct_buffer = list() + self.raw_score_buffer = list() + self.partial_fix_buffer = list() + + for item in self.eval_metrics: + wandb_metrics[item[0]] = item[1] + self.eval_metrics = list() + + await 
super().wandb_log(wandb_metrics) + + async def setup(self): + """Load the HumanEvalFix dataset and prepare train/test splits.""" + from datasets import load_dataset + + print("Loading HumanEvalFix dataset...") + dataset = load_dataset("bigcode/humanevalfix-python", split="test") + + all_items: List[CodeDebugItem] = [] + for row in dataset: + all_items.append( + { + "task_id": row["task_id"], + "prompt": row["declaration"], + "buggy_solution": row["buggy_solution"], + "canonical_solution": row["canonical_solution"], + "test": row["test"], + "entry_point": row["entry_point"], + } + ) + + print(f"Loaded {len(all_items)} problems") + + # Verify a few items actually work with canonical solutions + verified = 0 + for item in all_items[:10]: + code = item["prompt"] + item["canonical_solution"] + passed, _ = execute_code_with_tests( + code, item["test"], item["entry_point"] + ) + if passed: + verified += 1 + print(f"Verified {verified}/10 canonical solutions execute correctly") + + # Split 80/20 train/test + random.shuffle(all_items) + split_idx = int(len(all_items) * 0.8) + self.train = all_items[:split_idx] + self.test = all_items[split_idx:] + + print(f"Train: {len(self.train)}, Test: {len(self.test)}") + self.iter = 0 + + def save_checkpoint(self, step, data=None): + """Save checkpoint with iteration state.""" + if data is None: + data = {} + data["iter"] = self.iter + super().save_checkpoint(step, data) + + def _score_fix(self, generated_code: str, item: CodeDebugItem) -> Tuple[float, bool]: + """ + Score a generated fix by execution against test cases. 
+ + Returns: + Tuple of (score, is_partial_fix): + - score: 1.0 if all tests pass, proportional for partial, + -1.0 if no improvement or compilation error + - is_partial_fix: True if some but not all tests pass + """ + if not generated_code: + return -1.0, False + + # Run the fixed code against tests + all_passed, error = execute_code_with_tests( + generated_code, item["test"], item["entry_point"] + ) + + if all_passed: + return 1.0, False + + # Check for partial credit — how many tests pass? + passed, total = count_test_results( + generated_code, item["test"], item["entry_point"] + ) + + if total == 0: + return -1.0, False + + # Also check how the buggy code does + buggy_code = item["prompt"] + item["buggy_solution"] + buggy_passed, buggy_total = count_test_results( + buggy_code, item["test"], item["entry_point"] + ) + + # Score based on improvement over buggy code + if passed > buggy_passed: + # Partial improvement: scale between -0.5 and 0.9 + improvement_ratio = (passed - buggy_passed) / max(1, total - buggy_passed) + score = -0.5 + 1.4 * improvement_ratio + return score, True + elif passed == buggy_passed and passed > 0: + # No improvement but code at least runs + return -0.5, True + else: + # Made things worse or code doesn't compile + return -1.0, False + + async def rollout_and_score_eval(self, item: CodeDebugItem) -> dict: + """Rollout and score a single evaluation item.""" + user_content = format_debug_prompt(item) + + async with self.server.managed_server(tokenizer=self.tokenizer) as managed: + completion = await managed.chat_completion( + messages=[ + {"role": "system", "content": SYSTEM_PROMPT}, + {"role": "user", "content": user_content}, + ], + n=1, + max_tokens=self.config.max_token_length, + temperature=0.6, + ) + response_content = completion.choices[0].message.content + + # Extract and score generated fix + generated_code = extract_boxed_code(response_content) + score, is_partial = self._score_fix(generated_code, item) + correct = score == 1.0 + 
+ sample = { + "messages": [ + {"role": "system", "content": SYSTEM_PROMPT}, + {"role": "user", "content": user_content}, + {"role": "assistant", "content": response_content}, + ], + "question": f"Fix bug in {item['entry_point']}", + "gold_answer": item["canonical_solution"], + "model_parsed": generated_code or "(no code extracted)", + "score": 1 if correct else 0, + "correct": correct, + "finish_reason": completion.choices[0].finish_reason, + } + + return {"score": 1 if correct else 0, "sample": sample} + + async def evaluate(self, *args, **kwargs): + """Run evaluation on test set.""" + import time + + start_time = time.time() + + eval_tasks = [] + for item in self.test[:100]: + eval_tasks.append(self.rollout_and_score_eval(item)) + results = await tqdm_asyncio.gather(*eval_tasks) + + scores = [result["score"] for result in results] + samples = [result["sample"] for result in results] + + percent_correct = sum(scores) / len(scores) if scores else 0 + + end_time = time.time() + + self.eval_metrics.append(("eval/percent_correct", percent_correct)) + + eval_metrics = { + "eval/percent_correct": percent_correct, + } + + await self.evaluate_log( + metrics=eval_metrics, + samples=samples, + start_time=start_time, + end_time=end_time, + generation_parameters={ + "temperature": 0.6, + "max_tokens": self.config.max_token_length, + }, + ) + + async def collect_trajectories( + self, item: CodeDebugItem + ) -> Tuple[ScoredDataGroup, list[Item]]: + """Generate code fixes for a buggy function.""" + user_content = format_debug_prompt(item) + user_message = {"role": "user", "content": user_content} + + async with self.server.managed_server(tokenizer=self.tokenizer) as managed: + chat_completions = await managed.chat_completion( + messages=[ + {"role": "system", "content": SYSTEM_PROMPT}, + user_message, + ], + n=self.config.group_size, + max_tokens=self.config.max_token_length, + temperature=1.0, + ) + + state = managed.get_state() + nodes = state["nodes"] + + to_score = list() + 
to_backlog = list() + + for i, chat_completion in enumerate(chat_completions.choices): + messages = ( + {"role": "system", "content": SYSTEM_PROMPT}, + user_message, + {"role": "assistant", "content": chat_completion.message.content}, + ) + to_score.append( + { + "messages": messages, + "item": item, + "finish_reason": chat_completion.finish_reason, + "tokens": nodes[i].tokens, + "masks": nodes[i].masked_tokens, + "logprobs": nodes[i].logprobs, + } + ) + + to_postprocess = await self.score(to_score) + return to_postprocess, to_backlog + + async def score( + self, rollout_group_data + ) -> Union[Optional[ScoredDataGroup], List[Optional[ScoredDataGroup]]]: + """Score generated code fixes by execution against test cases.""" + scores = ScoredDataGroup() + scores["tokens"] = list() + scores["masks"] = list() + scores["scores"] = list() + scores["inference_logprobs"] = list() + + item = rollout_group_data[0]["item"] + + random.shuffle(rollout_group_data) + + for rollout in rollout_group_data: + response_content = rollout["messages"][-1]["content"] + + # Extract fixed code from \boxed{} + generated_code = extract_boxed_code(response_content) + + # Score by execution + reward, is_partial = self._score_fix(generated_code, item) + self.partial_fix_buffer.append(1 if is_partial else 0) + + tokens = rollout["tokens"] + masks = rollout["masks"] + logprobs = rollout["logprobs"] + + # Remove obviously bad examples (too short responses) + if len([1 for m in masks if m != -100]) < 10: + continue + + scores["tokens"].append(tokens) + scores["masks"].append(masks) + scores["inference_logprobs"].append(logprobs) + scores["scores"].append(reward) + + if len(scores["tokens"]) >= self.config.group_size: + break + + for score in scores["scores"]: + self.percent_correct_buffer.append(1.0 if score >= 1.0 else 0.0) + self.raw_score_buffer.append(max(score, 0.0)) + + # Apply length penalty when all scores are perfect + if scores["scores"] and all(score == 1.0 for score in scores["scores"]): + 
token_lengths = [len(t) for t in scores["tokens"]] + if max(token_lengths) == 0: + return None + + max_allowed_length = self.config.max_token_length + length_threshold = max_allowed_length * 0.5 + + scores["scores"] = [] + for length in token_lengths: + if length <= length_threshold: + scores["scores"].append(1.0) + else: + pct = (length - length_threshold) / ( + max_allowed_length - length_threshold + ) + pct = min(pct, 1.0) + scores["scores"].append(1.0 - pct) + + # If all scores are same, return None (GRPO needs variance) + if scores["scores"] and all( + scores["scores"][0] == s for s in scores["scores"] + ): + return None + + return scores + + async def get_next_item(self) -> CodeDebugItem: + """Get the next training item.""" + next_item = self.train[self.iter % len(self.train)] + self.iter += 1 + return next_item + + +if __name__ == "__main__": + CodeDebugEnv.cli() diff --git a/environments/community/code_debug_env/code_executor.py b/environments/community/code_debug_env/code_executor.py new file mode 100644 index 00000000..755cd1eb --- /dev/null +++ b/environments/community/code_debug_env/code_executor.py @@ -0,0 +1,213 @@ +""" +Safe code execution utilities for the Code Debug environment. + +Runs generated code in isolated subprocess with timeout and resource limits. +""" + +import os +import subprocess +import tempfile +from typing import Optional, Tuple + + +def execute_code_with_tests( + code: str, + test_code: str, + entry_point: str, + timeout: int = 10, +) -> Tuple[bool, str]: + """ + Execute code with test cases in an isolated subprocess. + + Args: + code: The function implementation to test. + test_code: Test code containing a `check(candidate)` function. + entry_point: The function name to pass to `check()`. + timeout: Maximum execution time in seconds. + + Returns: + Tuple of (all_tests_passed, error_message). 
+ """ + full_code = code + "\n\n" + test_code + f"\n\ncheck({entry_point})\n" + + tmp_path = None + try: + with tempfile.NamedTemporaryFile( + mode="w", suffix=".py", delete=False, prefix="code_debug_" + ) as f: + f.write(full_code) + tmp_path = f.name + + result = subprocess.run( + ["python3", tmp_path], + capture_output=True, + text=True, + timeout=timeout, + ) + if result.returncode == 0: + return True, "" + else: + error = result.stderr.strip() + # Truncate long tracebacks + if len(error) > 500: + error = error[-500:] + return False, error + + except subprocess.TimeoutExpired: + return False, "Execution timed out" + except Exception as e: + return False, f"{type(e).__name__}: {str(e)[:200]}" + finally: + if tmp_path and os.path.exists(tmp_path): + os.unlink(tmp_path) + + +def extract_boxed_code(text: str) -> Optional[str]: + """ + Extract code from \\boxed{...} format, handling nested braces. + + Finds the LAST occurrence of \\boxed{} to handle cases where + the model discusses code before providing the final answer. + + Args: + text: The model's response text. + + Returns: + The extracted code string, or None if no \\boxed{} found. + """ + idx = text.rfind("\\boxed{") + if idx == -1: + return None + + start = idx + len("\\boxed{") + brace_count = 1 + i = start + + while i < len(text) and brace_count > 0: + if text[i] == "{": + brace_count += 1 + elif text[i] == "}": + brace_count -= 1 + i += 1 + + if brace_count == 0: + return text[start : i - 1].strip() + return None + + +def count_test_results( + code: str, + test_code: str, + entry_point: str, + timeout: int = 10, +) -> Tuple[int, int]: + """ + Count how many individual assertions pass vs fail. + + Wraps each assertion in a try/except to count partial success. + + Args: + code: The function implementation to test. + test_code: Test code containing a `check(candidate)` function. + entry_point: The function name to pass. + timeout: Maximum execution time in seconds. 
+ + Returns: + Tuple of (passed_count, total_count). + """ + # Build a counting test harness + counter_code = f""" +import sys + +{code} + +_passed = 0 +_total = 0 + +def _counting_check(candidate): + global _passed, _total + import types + + # Get the original check function's code + _orig_check_src = '''{test_code}''' + _ns = {{'__builtins__': __builtins__}} + exec(_orig_check_src, _ns) + _orig_check = _ns.get('check') + + if _orig_check is None: + print("0/0") + return + + # Get the source of check and count assertions + import inspect + try: + src = inspect.getsource(_orig_check) + except (TypeError, OSError): + # Can't inspect — just run it + try: + _orig_check(candidate) + print("1/1") + except Exception: + print("0/1") + return + + # Count 'assert' lines + assert_lines = [l.strip() for l in src.split('\\n') if l.strip().startswith('assert')] + _total = max(len(assert_lines), 1) + + # Run the full check — if it passes, all assertions passed + try: + _orig_check(candidate) + _passed = _total + except AssertionError: + # Some failed — try to count + _passed = 0 + for line in assert_lines: + try: + exec(line, {{'__builtins__': __builtins__, '{entry_point}': candidate, 'candidate': candidate}}) + _passed += 1 + except Exception: + pass + except Exception: + _passed = 0 + + print(f"{{_passed}}/{{_total}}") + +_counting_check({entry_point}) +""" + + tmp_path = None + try: + with tempfile.NamedTemporaryFile( + mode="w", suffix=".py", delete=False, prefix="code_debug_count_" + ) as f: + f.write(counter_code) + tmp_path = f.name + + result = subprocess.run( + ["python3", tmp_path], + capture_output=True, + text=True, + timeout=timeout, + ) + + stdout = result.stdout.strip() + if "/" in stdout: + parts = stdout.split("/") + try: + return int(parts[0]), int(parts[1]) + except (ValueError, IndexError): + pass + + # Fallback: if execution succeeded, assume all passed + if result.returncode == 0: + return 1, 1 + return 0, 1 + + except subprocess.TimeoutExpired: + return 0, 
1 + except Exception: + return 0, 1 + finally: + if tmp_path and os.path.exists(tmp_path): + os.unlink(tmp_path) diff --git a/environments/community/code_debug_env/test_code_debug.py b/environments/community/code_debug_env/test_code_debug.py new file mode 100644 index 00000000..449dd0de --- /dev/null +++ b/environments/community/code_debug_env/test_code_debug.py @@ -0,0 +1,179 @@ +""" +Unit tests for the Code Debug environment. + +Tests the code execution, scoring, and extraction utilities +without requiring a running inference server. +""" + +import pytest + +from code_executor import count_test_results, execute_code_with_tests, extract_boxed_code + + +# ============================================================================ +# Tests for extract_boxed_code +# ============================================================================ + + +class TestExtractBoxedCode: + def test_simple_extraction(self): + text = r"""Here is the fix: +\boxed{def add(a, b): + return a + b +}""" + result = extract_boxed_code(text) + assert result is not None + assert "def add(a, b):" in result + assert "return a + b" in result + + def test_nested_braces(self): + text = r"""\boxed{def foo(x): + if x > 0: + return {x: x**2} + return {} +}""" + result = extract_boxed_code(text) + assert result is not None + assert "return {x: x**2}" in result + assert "return {}" in result + + def test_no_boxed(self): + text = "Just some text without any boxed content" + result = extract_boxed_code(text) + assert result is None + + def test_last_boxed_used(self): + text = r"""First attempt: \boxed{def bad(): pass} +Actually, let me reconsider: \boxed{def good(): + return 42 +}""" + result = extract_boxed_code(text) + assert result is not None + assert "def good():" in result + assert "return 42" in result + + def test_empty_boxed(self): + text = r"\boxed{}" + result = extract_boxed_code(text) + assert result == "" + + +# ============================================================================ +# 
Tests for execute_code_with_tests +# ============================================================================ + + +class TestExecuteCodeWithTests: + def test_correct_code_passes(self): + code = "def add(a, b):\n return a + b\n" + test_code = """def check(candidate): + assert candidate(1, 2) == 3 + assert candidate(-1, 1) == 0 + assert candidate(0, 0) == 0 +""" + passed, error = execute_code_with_tests(code, test_code, "add") + assert passed is True + assert error == "" + + def test_buggy_code_fails(self): + code = "def add(a, b):\n return a - b\n" + test_code = """def check(candidate): + assert candidate(1, 2) == 3 + assert candidate(-1, 1) == 0 +""" + passed, error = execute_code_with_tests(code, test_code, "add") + assert passed is False + assert error != "" + + def test_syntax_error(self): + code = "def add(a, b)\n return a + b\n" # Missing colon + test_code = """def check(candidate): + assert candidate(1, 2) == 3 +""" + passed, error = execute_code_with_tests(code, test_code, "add") + assert passed is False + + def test_infinite_loop_timeout(self): + code = "def loop(x):\n while True: pass\n" + test_code = """def check(candidate): + assert candidate(1) is None +""" + passed, error = execute_code_with_tests(code, test_code, "loop", timeout=2) + assert passed is False + assert "timed out" in error.lower() or "Timeout" in error + + def test_runtime_error(self): + code = "def divide(a, b):\n return a / b\n" + test_code = """def check(candidate): + assert candidate(1, 0) == 0 +""" + passed, error = execute_code_with_tests(code, test_code, "divide") + assert passed is False + + +# ============================================================================ +# Tests for count_test_results +# ============================================================================ + + +class TestCountTestResults: + def test_all_pass(self): + code = "def add(a, b):\n return a + b\n" + test_code = """def check(candidate): + assert candidate(1, 2) == 3 + assert candidate(0, 0) == 
0 +""" + passed, total = count_test_results(code, test_code, "add") + assert passed > 0 + assert passed == total + + def test_none_pass(self): + code = "def add(a, b):\n return 0\n" + test_code = """def check(candidate): + assert candidate(1, 2) == 3 + assert candidate(5, 5) == 10 +""" + passed, total = count_test_results(code, test_code, "add") + assert passed == 0 + + def test_syntax_error_returns_zero(self): + code = "def add(a, b)\n return a + b\n" + test_code = """def check(candidate): + assert candidate(1, 2) == 3 +""" + passed, total = count_test_results(code, test_code, "add") + assert passed == 0 + + +# ============================================================================ +# Tests for scoring logic +# ============================================================================ + + +class TestScoringLogic: + """Test the scoring logic that will be used in CodeDebugEnv._score_fix.""" + + def test_perfect_fix_scores_one(self): + """A fix that passes all tests should score 1.0.""" + code = "def add(a, b):\n return a + b\n" + test_code = """def check(candidate): + assert candidate(1, 2) == 3 + assert candidate(-1, 1) == 0 +""" + passed, error = execute_code_with_tests(code, test_code, "add") + assert passed is True + # Score would be 1.0 + + def test_no_fix_scores_negative(self): + """A fix that doesn't improve should score negatively.""" + buggy = "def add(a, b):\n return a - b\n" + test_code = """def check(candidate): + assert candidate(1, 2) == 3 +""" + passed, error = execute_code_with_tests(buggy, test_code, "add") + assert passed is False + # Score would be -1.0 + + +if __name__ == "__main__": + pytest.main([__file__, "-v"])