mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-19 12:57:58 +00:00
179 lines
5.6 KiB
Python
179 lines
5.6 KiB
Python
"""
|
|
Unit tests for the Code Debug environment.
|
|
|
|
Tests the code execution, scoring, and extraction utilities
|
|
without requiring a running inference server.
|
|
"""
|
|
|
|
import pytest
|
|
|
|
from code_executor import count_test_results, execute_code_with_tests, extract_boxed_code
|
|
|
|
|
|
# ============================================================================
|
|
# Tests for extract_boxed_code
|
|
# ============================================================================
|
|
|
|
|
|
class TestExtractBoxedCode:
|
|
def test_simple_extraction(self):
|
|
text = r"""Here is the fix:
|
|
\boxed{def add(a, b):
|
|
return a + b
|
|
}"""
|
|
result = extract_boxed_code(text)
|
|
assert result is not None
|
|
assert "def add(a, b):" in result
|
|
assert "return a + b" in result
|
|
|
|
def test_nested_braces(self):
|
|
text = r"""\boxed{def foo(x):
|
|
if x > 0:
|
|
return {x: x**2}
|
|
return {}
|
|
}"""
|
|
result = extract_boxed_code(text)
|
|
assert result is not None
|
|
assert "return {x: x**2}" in result
|
|
assert "return {}" in result
|
|
|
|
def test_no_boxed(self):
|
|
text = "Just some text without any boxed content"
|
|
result = extract_boxed_code(text)
|
|
assert result is None
|
|
|
|
def test_last_boxed_used(self):
|
|
text = r"""First attempt: \boxed{def bad(): pass}
|
|
Actually, let me reconsider: \boxed{def good():
|
|
return 42
|
|
}"""
|
|
result = extract_boxed_code(text)
|
|
assert result is not None
|
|
assert "def good():" in result
|
|
assert "return 42" in result
|
|
|
|
def test_empty_boxed(self):
|
|
text = r"\boxed{}"
|
|
result = extract_boxed_code(text)
|
|
assert result == ""
|
|
|
|
|
|
# ============================================================================
|
|
# Tests for execute_code_with_tests
|
|
# ============================================================================
|
|
|
|
|
|
class TestExecuteCodeWithTests:
|
|
def test_correct_code_passes(self):
|
|
code = "def add(a, b):\n return a + b\n"
|
|
test_code = """def check(candidate):
|
|
assert candidate(1, 2) == 3
|
|
assert candidate(-1, 1) == 0
|
|
assert candidate(0, 0) == 0
|
|
"""
|
|
passed, error = execute_code_with_tests(code, test_code, "add")
|
|
assert passed is True
|
|
assert error == ""
|
|
|
|
def test_buggy_code_fails(self):
|
|
code = "def add(a, b):\n return a - b\n"
|
|
test_code = """def check(candidate):
|
|
assert candidate(1, 2) == 3
|
|
assert candidate(-1, 1) == 0
|
|
"""
|
|
passed, error = execute_code_with_tests(code, test_code, "add")
|
|
assert passed is False
|
|
assert error != ""
|
|
|
|
def test_syntax_error(self):
|
|
code = "def add(a, b)\n return a + b\n" # Missing colon
|
|
test_code = """def check(candidate):
|
|
assert candidate(1, 2) == 3
|
|
"""
|
|
passed, error = execute_code_with_tests(code, test_code, "add")
|
|
assert passed is False
|
|
|
|
def test_infinite_loop_timeout(self):
|
|
code = "def loop(x):\n while True: pass\n"
|
|
test_code = """def check(candidate):
|
|
assert candidate(1) is None
|
|
"""
|
|
passed, error = execute_code_with_tests(code, test_code, "loop", timeout=2)
|
|
assert passed is False
|
|
assert "timed out" in error.lower() or "Timeout" in error
|
|
|
|
def test_runtime_error(self):
|
|
code = "def divide(a, b):\n return a / b\n"
|
|
test_code = """def check(candidate):
|
|
assert candidate(1, 0) == 0
|
|
"""
|
|
passed, error = execute_code_with_tests(code, test_code, "divide")
|
|
assert passed is False
|
|
|
|
|
|
# ============================================================================
|
|
# Tests for count_test_results
|
|
# ============================================================================
|
|
|
|
|
|
class TestCountTestResults:
|
|
def test_all_pass(self):
|
|
code = "def add(a, b):\n return a + b\n"
|
|
test_code = """def check(candidate):
|
|
assert candidate(1, 2) == 3
|
|
assert candidate(0, 0) == 0
|
|
"""
|
|
passed, total = count_test_results(code, test_code, "add")
|
|
assert passed > 0
|
|
assert passed == total
|
|
|
|
def test_none_pass(self):
|
|
code = "def add(a, b):\n return 0\n"
|
|
test_code = """def check(candidate):
|
|
assert candidate(1, 2) == 3
|
|
assert candidate(5, 5) == 10
|
|
"""
|
|
passed, total = count_test_results(code, test_code, "add")
|
|
assert passed == 0
|
|
|
|
def test_syntax_error_returns_zero(self):
|
|
code = "def add(a, b)\n return a + b\n"
|
|
test_code = """def check(candidate):
|
|
assert candidate(1, 2) == 3
|
|
"""
|
|
passed, total = count_test_results(code, test_code, "add")
|
|
assert passed == 0
|
|
|
|
|
|
# ============================================================================
|
|
# Tests for scoring logic
|
|
# ============================================================================
|
|
|
|
|
|
class TestScoringLogic:
|
|
"""Test the scoring logic that will be used in CodeDebugEnv._score_fix."""
|
|
|
|
def test_perfect_fix_scores_one(self):
|
|
"""A fix that passes all tests should score 1.0."""
|
|
code = "def add(a, b):\n return a + b\n"
|
|
test_code = """def check(candidate):
|
|
assert candidate(1, 2) == 3
|
|
assert candidate(-1, 1) == 0
|
|
"""
|
|
passed, error = execute_code_with_tests(code, test_code, "add")
|
|
assert passed is True
|
|
# Score would be 1.0
|
|
|
|
def test_no_fix_scores_negative(self):
|
|
"""A fix that doesn't improve should score negatively."""
|
|
buggy = "def add(a, b):\n return a - b\n"
|
|
test_code = """def check(candidate):
|
|
assert candidate(1, 2) == 3
|
|
"""
|
|
passed, error = execute_code_with_tests(buggy, test_code, "add")
|
|
assert passed is False
|
|
# Score would be -1.0
|
|
|
|
|
|
if __name__ == "__main__":
|
|
pytest.main([__file__, "-v"])
|