atropos/environments/community/sql_query_env/test_sql_executor.py
2026-01-06 15:14:26 +01:00

224 lines
6.8 KiB
Python

"""
Unit tests for SQL executor module.
"""
from sql_executor import (
create_table_from_wikisql,
execute_sql,
extract_boxed_sql,
normalize_result,
results_match,
)
class TestCreateTableFromWikiSQL:
"""Tests for create_table_from_wikisql function."""
def test_basic_table_creation(self):
"""Test creating a simple table."""
header = ["name", "age", "city"]
rows = [
["Alice", "30", "New York"],
["Bob", "25", "Los Angeles"],
]
conn = create_table_from_wikisql(header, rows)
result = execute_sql(conn, "SELECT * FROM data")
conn.close()
assert result is not None
assert len(result) == 2
def test_table_with_special_column_names(self):
"""Test columns with spaces and special characters."""
header = ["Player Name", "Goals-Scored", "Team (2024)"]
rows = [["Messi", "10", "Inter Miami"]]
conn = create_table_from_wikisql(header, rows)
# Column names should be sanitized
result = execute_sql(conn, "SELECT * FROM data")
conn.close()
assert result is not None
assert len(result) == 1
def test_empty_table(self):
"""Test creating a table with no rows."""
header = ["col1", "col2"]
rows = []
conn = create_table_from_wikisql(header, rows)
result = execute_sql(conn, "SELECT * FROM data")
conn.close()
assert result == []
class TestExecuteSQL:
"""Tests for execute_sql function."""
def test_valid_select(self):
"""Test a valid SELECT query."""
header = ["name", "age"]
rows = [["Alice", "30"], ["Bob", "25"]]
conn = create_table_from_wikisql(header, rows)
result = execute_sql(conn, "SELECT name FROM data WHERE age = '30'")
conn.close()
assert result == [("Alice",)]
def test_invalid_sql_returns_none(self):
"""Test that invalid SQL returns None."""
header = ["a"]
rows = [["1"]]
conn = create_table_from_wikisql(header, rows)
result = execute_sql(conn, "INVALID SQL SYNTAX")
conn.close()
assert result is None
def test_aggregation(self):
"""Test COUNT aggregation."""
header = ["category"]
rows = [["A"], ["A"], ["B"]]
conn = create_table_from_wikisql(header, rows)
result = execute_sql(conn, "SELECT COUNT(*) FROM data WHERE category = 'A'")
conn.close()
assert result == [(2,)]
class TestExtractBoxedSQL:
"""Tests for extract_boxed_sql function."""
def test_simple_boxed(self):
"""Test extracting from simple boxed format."""
text = "The answer is \\boxed{SELECT * FROM data}"
result = extract_boxed_sql(text)
assert result == "SELECT * FROM data"
def test_boxed_with_think_tags(self):
"""Test extracting when thinking tags are present."""
text = """<think>
Let me think about this query...
</think>
\\boxed{SELECT name FROM data WHERE age > 25}"""
result = extract_boxed_sql(text)
assert result == "SELECT name FROM data WHERE age > 25"
def test_no_boxed_returns_none(self):
"""Test that missing boxed returns None."""
text = "SELECT * FROM data"
result = extract_boxed_sql(text)
assert result is None
def test_empty_boxed(self):
"""Test empty boxed returns None."""
text = "\\boxed{}"
result = extract_boxed_sql(text)
assert result is None
class TestResultsMatch:
"""Tests for results_match function."""
def test_identical_results(self):
"""Test matching identical results."""
result1 = [("Alice", "30"), ("Bob", "25")]
result2 = [("Alice", "30"), ("Bob", "25")]
assert results_match(result1, result2) is True
def test_different_order(self):
"""Test matching results in different order."""
result1 = [("Alice",), ("Bob",)]
result2 = [("Bob",), ("Alice",)]
assert results_match(result1, result2) is True
def test_case_insensitive(self):
"""Test case-insensitive matching."""
result1 = [("ALICE",)]
result2 = [("alice",)]
assert results_match(result1, result2) is True
def test_different_results(self):
"""Test non-matching results."""
result1 = [("Alice",)]
result2 = [("Bob",)]
assert results_match(result1, result2) is False
def test_none_result(self):
"""Test that None results don't match."""
assert results_match(None, [("a",)]) is False
assert results_match([("a",)], None) is False
class TestNormalizeResult:
"""Tests for normalize_result function."""
def test_normalization(self):
"""Test result normalization."""
result = [("Alice", "30"), ("BOB", "25")]
normalized = normalize_result(result)
# Should be lowercase and sorted
assert normalized == [("25", "bob"), ("30", "alice")] or normalized == [
("alice", "30"),
("bob", "25"),
]
class TestIntegration:
"""Integration tests for the full scoring pipeline."""
def test_correct_sql_scores_1(self):
"""Test that matching SQL results score correctly."""
header = ["name", "age"]
rows = [["Alice", "30"], ["Bob", "25"]]
conn = create_table_from_wikisql(header, rows)
gold_sql = "SELECT name FROM data WHERE age = '30'"
generated_sql = "SELECT name FROM data WHERE age = '30'"
gold_result = execute_sql(conn, gold_sql)
gen_result = execute_sql(conn, generated_sql)
conn.close()
assert results_match(gen_result, gold_result) is True
def test_semantically_equivalent_sql(self):
"""Test that semantically equivalent but differently written SQL matches."""
header = ["name", "age"]
rows = [["Alice", "30"], ["Bob", "25"]]
conn = create_table_from_wikisql(header, rows)
gold_sql = "SELECT name FROM data WHERE age = '30'"
# Different whitespace/formatting but same result
generated_sql = "select NAME from data where AGE='30'"
gold_result = execute_sql(conn, gold_sql)
gen_result = execute_sql(conn, generated_sql)
conn.close()
assert results_match(gen_result, gold_result) is True
def test_wrong_sql_fails(self):
"""Test that incorrect SQL doesn't match."""
header = ["name", "age"]
rows = [["Alice", "30"], ["Bob", "25"]]
conn = create_table_from_wikisql(header, rows)
gold_sql = "SELECT name FROM data WHERE age = '30'"
generated_sql = "SELECT name FROM data WHERE age = '25'"
gold_result = execute_sql(conn, gold_sql)
gen_result = execute_sql(conn, generated_sql)
conn.close()
assert results_match(gen_result, gold_result) is False