Add SQL Query Generation Environment

This commit is contained in:
PLippmann 2026-01-05 16:43:34 +01:00
parent 6d79d4c7ad
commit c927794248
6 changed files with 1427 additions and 0 deletions

View file

@ -0,0 +1,115 @@
# SQL Query Generation Environment
Train LLMs to generate correct SQL queries from natural language questions.
## Overview
This environment uses the [Salesforce/WikiSQL](https://huggingface.co/datasets/Salesforce/wikisql) dataset to train language models on text-to-SQL tasks. Queries are verified by **executing** the generated SQL against in-memory SQLite databases and comparing results to ground truth.
## Dataset
- **Source**: [Salesforce/WikiSQL](https://huggingface.co/datasets/Salesforce/wikisql)
- **Size**: 80,654 examples (train + validation + test)
- **Format**: Natural language questions with table schemas and ground truth SQL
## Usage
### Training Mode (with API Server)
```bash
# Terminal 1: Start the Atropos API
run-api
# Terminal 2: Run the environment
python sql_query_env.py serve --slurm False
```
### Local Testing (without API)
```bash
python sql_query_env.py process --env.data_path_to_save_groups sql_output.jsonl
```
This generates `sql_output.jsonl` and `sql_output.html` for inspection.
### With Local vLLM Server
```bash
python sql_query_env.py process \
--env.data_path_to_save_groups sql_output.jsonl \
--openai.base_url http://localhost:9001/v1 \
--openai.model_name YOUR_MODEL_NAME
```
## Reward Function
| Score | Condition |
|-------|-----------|
| **1.0** | Generated SQL executes and returns same result as gold SQL |
| **-1.0** | SQL fails to execute or returns incorrect result |
When all responses in a group are correct, a length penalty is applied to encourage concise solutions.
## Prompt Format
The model receives a table schema and question:
```
Table: data
Columns: col1, col2, col3
Sample data:
value1 | value2 | value3
Question: What is the value of col1 where col2 equals X?
```
Output should be in boxed format:
```
<think>
[Chain of thought reasoning]
</think>
\boxed{SELECT col1 FROM data WHERE col2 = 'X'}
```
## Unit Tests
```bash
# Run unit tests
python -m pytest test_sql_executor.py -v
```
All 19 tests cover:
- Table creation with special column names
- SQL execution and error handling
- `\boxed{}` extraction patterns
- Result comparison and normalization
- End-to-end scoring integration
## LLM Integration Test
The environment has been verified with Qwen3-8B on an NVIDIA H200:
```bash
# Run integration test with a local vLLM server
python test_integration.py --base_url http://localhost:8000/v1 --model Qwen/Qwen3-8B
```
Test results:
- **40% accuracy** on 10 random WikiSQL examples
- SQL extraction from `\boxed{}` working correctly
- Execution-based scoring producing correct reward signals
## Files
| File | Description |
|------|-------------|
| `sql_query_env.py` | Main environment implementation |
| `sql_executor.py` | SQLite execution and scoring utilities |
| `wikisql_loader.py` | WikiSQL dataset loader (from GitHub) |
| `test_sql_executor.py` | Unit tests (19 tests) |
| `test_integration.py` | LLM integration test |
## Author
Community contribution to Atropos.

View file

@ -0,0 +1,166 @@
"""
SQL Executor Module
Provides safe SQL execution against in-memory SQLite databases.
Used by the SQL Query Environment for reward verification.
"""
import re
import sqlite3
from typing import Any, List, Optional, Tuple
def create_table_from_wikisql(
    header: List[str],
    rows: List[List[Any]],
    types: Optional[List[str]] = None,
) -> sqlite3.Connection:
    """
    Create an in-memory SQLite table from WikiSQL format.

    Args:
        header: List of column names
        rows: List of row data (each row is a list of values)
        types: Optional list of column types from WikiSQL (currently unused;
            every column is stored as TEXT)

    Returns:
        SQLite connection with the table created and populated
    """
    connection = sqlite3.connect(":memory:")
    cur = connection.cursor()
    # Double-quote every column name so spaces and punctuation are legal
    # identifiers; SQLite accepts any string inside double quotes.
    column_defs = ", ".join(f'"{name}" TEXT' for name in header)
    cur.execute(f"CREATE TABLE data ({column_defs})")
    # Some WikiSQL queries reference a table literally called "table";
    # expose it as a view over data so those queries still run.
    try:
        cur.execute("CREATE VIEW 'table' AS SELECT * FROM data")
    except sqlite3.OperationalError:
        pass  # view already exists or the name conflicts — alias is best-effort
    # Populate the table, coercing every value to a string (None becomes "")
    # so comparisons behave uniformly.
    if rows:
        slots = ", ".join("?" for _ in header)
        as_text = [
            ["" if value is None else str(value) for value in row] for row in rows
        ]
        cur.executemany(f"INSERT INTO data VALUES ({slots})", as_text)
    connection.commit()
    return connection
def quote_identifiers_in_sql(sql: str, header: List[str]) -> str:
    """
    Quote column names in SQL that contain special characters.

    WikiSQL's human_readable SQL often has unquoted columns like:
        SELECT State/territory FROM table WHERE ...

    This function adds double quotes around column names that need them.

    Args:
        sql: SQL string possibly containing unquoted special column names
        header: Column names of the table the SQL targets

    Returns:
        SQL string with special column names double-quoted
    """
    result = sql
    # Sort by length (longest first) so a longer name is quoted before any
    # shorter name that happens to be a substring of it.
    sorted_headers = sorted(header, key=len, reverse=True)
    for col in sorted_headers:
        # Simple identifiers (letters/digits/underscore, not starting with a
        # digit) never need quoting; skip them.
        if re.match(r"^[a-zA-Z_][a-zA-Z0-9_]*$", col):
            continue
        # Escape regex metacharacters in the column name
        col_escaped = re.escape(col)
        # Bug fix: the previous pattern used \b word boundaries, which only
        # match at word-character edges — a column starting or ending with a
        # non-word character (e.g. 'Team (2024)') was never matched, so it
        # was never quoted. Use explicit lookarounds instead: the name must
        # not be adjacent to a word character or an existing double quote.
        pattern = rf'(?<![\w"]){col_escaped}(?![\w"])'
        replacement = f'"{col}"'
        result = re.sub(pattern, replacement, result)
    return result
def execute_sql(conn: sqlite3.Connection, sql: str) -> Optional[List[Tuple]]:
    """
    Execute SQL safely and return results.

    Args:
        conn: SQLite connection
        sql: SQL query to execute

    Returns:
        List of result tuples, or None if execution failed for any reason
    """
    try:
        # connection.execute() is the sqlite3 shortcut that creates a cursor,
        # runs the statement, and returns the cursor.
        return conn.execute(sql).fetchall()
    except Exception:
        # Any failure (syntax error, missing column/table, etc.) is treated
        # uniformly as "query invalid" by returning None.
        return None
def extract_boxed_sql(text: str) -> Optional[str]:
    """
    Extract SQL from \\boxed{} format in LLM response.

    Args:
        text: LLM response text

    Returns:
        Extracted SQL string, or None if not found (or if the box is empty)
    """
    # First pattern tolerates one level of nested braces inside the box;
    # second is a lazy fallback grabbing the shortest braced span.
    for candidate in (
        r"\\boxed\{([^{}]*(?:\{[^{}]*\}[^{}]*)*)\}",
        r"\\boxed\{(.+?)\}",
    ):
        found = re.search(candidate, text, re.DOTALL)
        if found:
            extracted = found.group(1).strip()
            if extracted:
                return extracted
    return None
def normalize_result(result: List[Tuple]) -> List[Tuple]:
    """
    Normalize query results for comparison.

    Every cell becomes a lowercased, stripped string (None becomes ""),
    and the rows are returned in sorted order so comparisons are
    order-independent.
    """
    return sorted(
        tuple(
            "" if cell is None else str(cell).lower().strip() for cell in row
        )
        for row in result
    )
def results_match(
    result1: Optional[List[Tuple]], result2: Optional[List[Tuple]]
) -> bool:
    """
    Compare two SQL query results for equality.

    Args:
        result1: First result set (None when its query failed to execute)
        result2: Second result set (None when its query failed to execute)

    Returns:
        True if results match (order-independent, case-insensitive);
        False if either result is None.
    """
    # Fix: the parameters were annotated as plain List[Tuple] even though
    # the function explicitly accepts and handles None (a failed execution);
    # the annotations now reflect the real contract. A failed execution
    # never matches anything, including another failed execution.
    if result1 is None or result2 is None:
        return False
    norm1 = normalize_result(result1)
    norm2 = normalize_result(result2)
    return norm1 == norm2

View file

@ -0,0 +1,419 @@
"""
SQL Query Generation Environment for Atropos
Trains LLMs to generate correct SQL queries from natural language questions.
Uses the Salesforce/WikiSQL dataset with execution-based scoring.
"""
import random
from typing import Any, Dict, List, Optional, Tuple, TypedDict, Union
from sql_executor import (
create_table_from_wikisql,
execute_sql,
extract_boxed_sql,
results_match,
)
from tqdm.asyncio import tqdm_asyncio
from wikisql_loader import load_wikisql_split
from atroposlib.envs.base import (
APIServerConfig,
BaseEnv,
BaseEnvConfig,
ScoredDataGroup,
)
from atroposlib.type_definitions import Item
# System prompt following the established Atropos pattern:
# a shared "deep thinking" preamble (<think> tags) plus task instructions.
system_prompt = (
    "You are a deep thinking AI, you may use extremely long chains of thought "
    "to deeply consider the problem and deliberate with yourself via systematic "
    "reasoning processes to help come to a correct solution prior to answering. "
    "You should enclose your thoughts and internal monologue inside <think> </think> "
    "tags, and then provide your solution or response to the problem.\n\n"
)
# Task-specific instructions: produce SQLite SQL against the fixed table
# name "data" and wrap the final query in \boxed{...} for extraction.
system_prompt += """You are a SQL expert. Given a table schema and a natural language question,
generate a SQL query that answers the question.
You are allocated a maximum of 1024 tokens, please strive to use less.
Provide your SQL query inside \\boxed{} like this: \\boxed{SELECT column FROM table WHERE condition}
Important:
- Use only the columns provided in the table schema
- The table is always named "data"
- Ensure your SQL syntax is valid SQLite
So please end your answer with \\boxed{your SQL query here}"""
class WikiSQLRow(TypedDict):
    """Type definition for a WikiSQL dataset row."""

    question: str  # natural-language question to answer
    header: List[str]  # column names of the source table
    rows: List[List[Any]]  # table cell values, row-major
    types: List[str]  # WikiSQL column types (not used by the scorer here)
    gold_sql: str  # ground-truth SQL for the question
def format_table_schema(
    header: List[str], rows: List[List[Any]], max_rows: int = 3
) -> str:
    """Build the prompt's table description: name, columns, and up to
    *max_rows* sample rows rendered with " | " separators."""
    parts = [f"Table: data\nColumns: {', '.join(header)}\n"]
    if rows:
        parts.append("Sample data:\n")
        for sample in rows[:max_rows]:
            rendered = " | ".join(str(cell) for cell in sample)
            parts.append(f" {rendered}\n")
    return "".join(parts)
class SQLQueryEnv(BaseEnv):
    """
    Environment for training LLMs to generate SQL queries.

    Uses the Salesforce/WikiSQL dataset and verifies correctness
    by executing SQL against in-memory SQLite databases.
    """

    name = "sql_query"

    def __init__(
        self,
        config: BaseEnvConfig,
        server_configs: List[APIServerConfig],
        slurm=True,
        testing=False,
    ):
        super().__init__(config, server_configs, slurm, testing)
        # Rolling metric buffers; flushed and cleared on each wandb_log().
        self.percent_correct_buffer = list()
        self.eval_metrics = list()
        self.execution_success_buffer = list()

    @classmethod
    def config_init(cls) -> Tuple[BaseEnvConfig, List[APIServerConfig]]:
        """Initialize default configuration for the environment."""
        env_config = BaseEnvConfig(
            tokenizer_name="NousResearch/DeepHermes-3-Llama-3-3B-Preview",
            group_size=8,
            use_wandb=True,
            rollout_server_url="http://localhost:8000",
            total_steps=1000,
            batch_size=12,
            steps_per_eval=100,
            max_token_length=1024,
            wandb_name="sql_query",
        )
        server_configs = [
            APIServerConfig(
                model_name="NousResearch/DeepHermes-3-Llama-3-3B-Preview",
                base_url="http://localhost:9001/v1",
                api_key="x",
                num_requests_for_eval=256,
            ),
        ]
        return env_config, server_configs

    async def wandb_log(self, wandb_metrics: Optional[Dict] = None):
        """Log custom metrics to WandB and reset the rolling buffers."""
        if wandb_metrics is None:
            wandb_metrics = {}
        # Log percent correct (ZeroDivisionError means the buffer is empty —
        # skip the metric for this step)
        try:
            wandb_metrics["train/percent_correct"] = sum(
                self.percent_correct_buffer
            ) / len(self.percent_correct_buffer)
        except ZeroDivisionError:
            pass
        # Log execution success rate (same empty-buffer handling)
        try:
            wandb_metrics["train/execution_success"] = sum(
                self.execution_success_buffer
            ) / len(self.execution_success_buffer)
        except ZeroDivisionError:
            pass
        self.percent_correct_buffer = list()
        self.execution_success_buffer = list()
        for item in self.eval_metrics:
            wandb_metrics[item[0]] = item[1]
        self.eval_metrics = list()
        await super().wandb_log(wandb_metrics)

    async def setup(self):
        """Load the WikiSQL dataset and prepare train/test splits."""
        # Load WikiSQL dataset directly from GitHub source
        print("Loading WikiSQL training data...")
        self.train = load_wikisql_split("train")
        print(f"Loaded {len(self.train)} training examples")
        print("Loading WikiSQL test data...")
        self.test = load_wikisql_split("test")
        print(f"Loaded {len(self.test)} test examples")
        random.shuffle(self.train)
        self.iter = 0

    def save_checkpoint(self, step, data=None):
        """Save checkpoint with iteration state so resumed runs keep their
        place in the shuffled training set."""
        if data is None:
            data = {}
        data["iter"] = self.iter
        super().save_checkpoint(step, data)

    def _score_sql(
        self,
        generated_sql: str,
        gold_sql: str,
        header: List[str],
        rows: List[List[Any]],
    ) -> Tuple[float, bool]:
        """
        Score SQL by execution comparison.

        Args:
            generated_sql: SQL extracted from the model response (may be
                None/empty when extraction failed)
            gold_sql: Ground-truth SQL for the question
            header: Table column names
            rows: Table cell values

        Returns:
            Tuple of (score, execution_success). Score is 1.0 when the
            generated query's results match the gold query's, else -1.0.
            execution_success is True only when the generated SQL ran
            without error.
        """
        if not generated_sql:
            return -1.0, False
        # Create in-memory table
        try:
            conn = create_table_from_wikisql(header, rows)
        except Exception:
            return -1.0, False
        # Execute gold SQL; if it fails, the example itself is unusable
        gold_result = execute_sql(conn, gold_sql)
        if gold_result is None:
            conn.close()
            return -1.0, False
        # Execute generated SQL
        gen_result = execute_sql(conn, generated_sql)
        conn.close()
        if gen_result is None:
            return -1.0, False
        # Compare results (order-independent, case-insensitive)
        if results_match(gen_result, gold_result):
            return 1.0, True
        else:
            return -1.0, True

    async def rollout_and_score_eval(
        self, question: str, gold_sql: str, header: List[str], rows: List[List[Any]]
    ) -> dict:
        """Rollout and score a single evaluation item.

        Returns:
            Dict with a 0/1 "score" and a "sample" payload for eval logging.
        """
        table_schema = format_table_schema(header, rows)
        user_content = f"{table_schema}\nQuestion: {question}"
        async with self.server.managed_server(tokenizer=self.tokenizer) as managed:
            completion = await managed.chat_completion(
                messages=[
                    {"role": "system", "content": system_prompt},
                    {"role": "user", "content": user_content},
                ],
                n=1,
                max_tokens=self.config.max_token_length,
                temperature=0.6,
            )
        response_content = completion.choices[0].message.content
        # Extract and score generated SQL
        generated_sql = extract_boxed_sql(response_content)
        score, exec_success = self._score_sql(generated_sql, gold_sql, header, rows)
        correct = score == 1.0
        sample = {
            "messages": [
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_content},
                {"role": "assistant", "content": response_content},
            ],
            "question": question,
            "gold_sql": gold_sql,
            "generated_sql": generated_sql,
            "score": 1 if correct else 0,
            "correct": correct,
            "execution_success": exec_success,
            "finish_reason": completion.choices[0].finish_reason,
        }
        return {"score": 1 if correct else 0, "sample": sample}

    async def evaluate(self, *args, **kwargs):
        """Run evaluation on (up to) the first 500 test-set items."""
        import time

        start_time = time.time()
        eval_tasks = []
        # Limit eval to first 500 items for efficiency
        for item in self.test[:500]:
            eval_tasks.append(
                self.rollout_and_score_eval(
                    item["question"], item["gold_sql"], item["header"], item["rows"]
                )
            )
        results = await tqdm_asyncio.gather(*eval_tasks)
        scores = [result["score"] for result in results]
        samples = [result["sample"] for result in results]
        percent_correct = sum(scores) / len(scores) if scores else 0
        end_time = time.time()
        self.eval_metrics.append(("eval/percent_correct", percent_correct))
        eval_metrics = {
            "eval/percent_correct": percent_correct,
        }
        await self.evaluate_log(
            metrics=eval_metrics,
            samples=samples,
            start_time=start_time,
            end_time=end_time,
            generation_parameters={
                "temperature": 0.6,
                "max_tokens": self.config.max_token_length,
            },
        )

    async def collect_trajectories(
        self, item: WikiSQLRow
    ) -> Tuple[ScoredDataGroup, list[Item]]:
        """Generate group_size SQL candidates for one question and score them."""
        table_schema = format_table_schema(item["header"], item["rows"])
        user_content = f"{table_schema}\nQuestion: {item['question']}"
        user_message = {"role": "user", "content": user_content}
        async with self.server.managed_server(tokenizer=self.tokenizer) as managed:
            chat_completions = await managed.chat_completion(
                messages=[{"role": "system", "content": system_prompt}, user_message],
                n=self.config.group_size,
                max_tokens=self.config.max_token_length,
                temperature=1.0,
            )
            # Token/mask/logprob data for each completion lives in the
            # managed server's node state.
            state = managed.get_state()
            nodes = state["nodes"]
        to_score = list()
        to_backlog = list()
        for i, chat_completion in enumerate(chat_completions.choices):
            messages = (
                {"role": "system", "content": system_prompt},
                user_message,
                {"role": "assistant", "content": chat_completion.message.content},
            )
            to_score.append(
                {
                    "messages": messages,
                    "gold_sql": item["gold_sql"],
                    "header": item["header"],
                    "rows": item["rows"],
                    "finish_reason": chat_completion.finish_reason,
                    "tokens": nodes[i].tokens,
                    "masks": nodes[i].masked_tokens,
                    "logprobs": nodes[i].logprobs,
                }
            )
        to_postprocess = await self.score(to_score)
        return to_postprocess, to_backlog

    async def score(
        self, rollout_group_data
    ) -> Union[Optional[ScoredDataGroup], List[Optional[ScoredDataGroup]]]:
        """Score generated SQL queries by execution comparison.

        Returns None when the group carries no training signal (all scores
        identical after any length penalty, or no usable rollouts),
        otherwise a populated ScoredDataGroup.
        """
        scores = ScoredDataGroup()
        scores["tokens"] = list()
        scores["masks"] = list()
        scores["scores"] = list()
        # Bug fix: this key was appended to below without ever being
        # initialized, which raised KeyError on the first scored rollout.
        scores["inference_logprobs"] = list()
        # Get table info from first item (shared by the whole group)
        gold_sql = rollout_group_data[0]["gold_sql"]
        header = rollout_group_data[0]["header"]
        rows = rollout_group_data[0]["rows"]
        random.shuffle(rollout_group_data)
        for item in rollout_group_data:
            response_content = item["messages"][-1]["content"]
            # Extract SQL from response
            generated_sql = extract_boxed_sql(response_content)
            # Score by execution
            reward, exec_success = self._score_sql(
                generated_sql, gold_sql, header, rows
            )
            self.execution_success_buffer.append(1 if exec_success else 0)
            tokens = item["tokens"]
            masks = item["masks"]
            logprobs = item["logprobs"]
            # Remove obviously bad examples (fewer than 10 unmasked tokens)
            if len([1 for i in masks if i != -100]) < 10:
                continue
            scores["tokens"].append(tokens)
            scores["masks"].append(masks)
            scores["inference_logprobs"].append(logprobs)
            scores["scores"].append(reward)
            if len(scores["tokens"]) >= self.config.group_size:
                break
        for score in scores["scores"]:
            self.percent_correct_buffer.append(max(score, 0))
        # Check if all scores are the same
        if all([score == 1 for score in scores["scores"]]):
            # Apply length penalty when all are correct
            token_lengths = [len(token) for token in scores["tokens"]]
            # Bug fix: also guard the empty case — max() on an empty list
            # raised ValueError when every rollout was filtered out above.
            if not token_lengths or max(token_lengths) == 0:
                return None
            max_allowed_length = self.config.max_token_length
            length_threshold = max_allowed_length * 0.5
            scores["scores"] = []
            for length in token_lengths:
                if length <= length_threshold:
                    scores["scores"].append(1.0)
                else:
                    # Linearly decay from 1.0 to 0.0 over the upper half of
                    # the allowed token budget.
                    percentage_of_range = (length - length_threshold) / (
                        max_allowed_length - length_threshold
                    )
                    percentage_of_range = min(percentage_of_range, 1.0)
                    scores["scores"].append(1.0 - percentage_of_range)
        if all([scores["scores"][0] == score for score in scores["scores"]]):
            return None  # If all the same, return None
        return scores

    async def get_next_item(self) -> WikiSQLRow:
        """Get the next training item, cycling through the shuffled set."""
        next_item = self.train[self.iter % len(self.train)]
        self.iter += 1
        return next_item
if __name__ == "__main__":
    # Entry point: delegate to the Atropos BaseEnv CLI
    # (e.g. `python sql_query_env.py serve` / `... process`, per the README).
    SQLQueryEnv.cli()

View file

@ -0,0 +1,276 @@
#!/usr/bin/env python3
"""
Integration test for SQL Query Environment that works with OpenAI-compatible APIs.
This test verifies:
1. WikiSQL dataset loading
2. SQL generation from LLM
3. SQL extraction from \boxed{}
4. SQL execution and result comparison
5. Scoring logic
"""
import asyncio
import json
from typing import Any, List
import openai
# Import local modules
from sql_executor import (
create_table_from_wikisql,
execute_sql,
extract_boxed_sql,
quote_identifiers_in_sql,
results_match,
)
from wikisql_loader import load_wikisql_split
# System prompt from the environment.
# NOTE(review): this mirrors `system_prompt` in sql_query_env.py by hand —
# if the environment prompt changes, update this copy together with it.
SYSTEM_PROMPT = (
    "You are a deep thinking AI, you may use extremely long chains of thought "
    "to deeply consider the problem and deliberate with yourself via systematic "
    "reasoning processes to help come to a correct solution prior to answering. "
    "You should enclose your thoughts and internal monologue inside <think> </think> "
    "tags, and then provide your solution or response to the problem.\n\n"
    "You are a SQL expert. Given a table schema and a natural language question, "
    "generate a SQL query that answers the question.\n\n"
    "You are allocated a maximum of 1024 tokens, please strive to use less.\n\n"
    "Provide your SQL query inside \\boxed{} like this: "
    "\\boxed{SELECT column FROM table WHERE condition}\n\n"
    "Important:\n"
    "- Use only the columns provided in the table schema\n"
    '- The table is always named "data"\n'
    "- Ensure your SQL syntax is valid SQLite\n\n"
    "So please end your answer with \\boxed{your SQL query here}"
)
def format_table_schema(
    header: List[str], rows: List[List[Any]], max_rows: int = 3
) -> str:
    """Format table schema for the prompt: table name, column list, and up
    to *max_rows* sample rows joined with " | "."""
    text = f"Table: data\nColumns: {', '.join(header)}\n"
    if not rows:
        return text
    sample_lines = "".join(
        " " + " | ".join(str(value) for value in row) + "\n"
        for row in rows[:max_rows]
    )
    return text + "Sample data:\n" + sample_lines
def score_sql(
    generated_sql: str,
    gold_sql: str,
    header: List[str],
    rows: List[List[Any]],
) -> dict:
    """Score a generated query against the gold query via execution,
    returning a detailed result dict (score, raw results, error reason)."""
    outcome = {
        "generated_sql": generated_sql,
        "gold_sql": gold_sql,
        "score": -1.0,
        "execution_success": False,
        "gen_result": None,
        "gold_result": None,
        "error": None,
    }
    if not generated_sql:
        outcome["error"] = "No SQL extracted from response"
        return outcome
    # Build the in-memory table; a failure here is an environment problem,
    # not a model problem, but still scores -1.
    try:
        conn = create_table_from_wikisql(header, rows)
    except Exception as e:
        outcome["error"] = f"Table creation failed: {e}"
        return outcome
    # Gold SQL may contain unquoted special column names; quote them first.
    quoted_gold_sql = quote_identifiers_in_sql(gold_sql, header)
    outcome["quoted_gold_sql"] = quoted_gold_sql
    gold_rows = execute_sql(conn, quoted_gold_sql)
    if gold_rows is None:
        conn.close()
        outcome["error"] = f"Gold SQL execution failed: {quoted_gold_sql}"
        return outcome
    outcome["gold_result"] = str(gold_rows)
    gen_rows = execute_sql(conn, generated_sql)
    conn.close()
    if gen_rows is None:
        outcome["error"] = "Generated SQL execution failed"
        return outcome
    outcome["gen_result"] = str(gen_rows)
    outcome["execution_success"] = True
    # Execution succeeded — compare normalized result sets.
    if results_match(gen_rows, gold_rows):
        outcome["score"] = 1.0
    else:
        outcome["error"] = "Results do not match"
    return outcome
async def test_single_item(client, model_name: str, item: dict, item_idx: int) -> dict:
    """Run one WikiSQL item through the model and score the result.

    Returns a dict merging the score_sql() outcome with the (truncated)
    model response; on any failure, a minimal dict with the error string.
    """
    schema_text = format_table_schema(item["header"], item["rows"])
    user_content = f"{schema_text}\nQuestion: {item['question']}"
    try:
        response = await client.chat.completions.create(
            model=model_name,
            messages=[
                {"role": "system", "content": SYSTEM_PROMPT},
                {"role": "user", "content": user_content},
            ],
            max_tokens=1024,
            temperature=0.6,
        )
        reply = response.choices[0].message.content
        # Pull the SQL out of \boxed{} and score it by execution.
        verdict = score_sql(
            extract_boxed_sql(reply), item["gold_sql"], item["header"], item["rows"]
        )
        # Keep stored responses small: truncate anything over 500 chars.
        truncated = reply if len(reply) <= 500 else reply[:500] + "..."
        return {
            "item_idx": item_idx,
            "question": item["question"],
            "response": truncated,
            **verdict,
        }
    except Exception as e:
        # API/transport failures are reported per-item, not raised.
        return {
            "item_idx": item_idx,
            "question": item["question"],
            "error": str(e),
            "score": -1.0,
        }
async def run_integration_test(
    base_url: str,
    model_name: str,
    api_key: str = "x",
    num_samples: int = 10,
):
    """Run the integration test.

    Args:
        base_url: OpenAI-compatible endpoint (e.g. http://localhost:8000/v1)
        model_name: Model identifier to request from the server
        api_key: API key (the dummy "x" works for local vLLM servers)
        num_samples: Number of WikiSQL training items to evaluate

    Returns:
        List of per-item result dicts; also written to
        integration_test_results.json.
    """
    print(f"\n{'='*60}")
    print("SQL Query Environment Integration Test")
    print(f"{'='*60}")
    print(f"Server: {base_url}")
    print(f"Model: {model_name}")
    print(f"Samples: {num_samples}")
    print()
    # Load dataset
    print("Loading WikiSQL training data...")
    train_data = load_wikisql_split("train")
    print(f"Loaded {len(train_data)} training examples")
    # Initialize OpenAI client
    client = openai.AsyncClient(
        base_url=base_url,
        api_key=api_key,
        timeout=120.0,
    )
    # Run tests sequentially so per-item progress prints in order
    print(f"\nTesting {num_samples} samples...\n")
    results = []
    for i in range(min(num_samples, len(train_data))):
        item = train_data[i]
        print(f"[{i+1}/{num_samples}] Testing: {item['question'][:60]}...")
        result = await test_single_item(client, model_name, item, i)
        results.append(result)
        # Print result (SQL strings truncated to 80 chars for readability)
        if result["score"] == 1.0:
            print(f" ✓ CORRECT - Generated: {result.get('generated_sql', 'N/A')[:80]}")
        else:
            print(f" ✗ INCORRECT - {result.get('error', 'Unknown error')}")
            if result.get("generated_sql"):
                print(f" Generated: {result['generated_sql'][:80]}")
                print(f" Gold: {result.get('gold_sql', 'N/A')[:80]}")
    # Summary
    print(f"\n{'='*60}")
    print("SUMMARY")
    print(f"{'='*60}")
    correct = sum(1 for r in results if r["score"] == 1.0)
    executed = sum(1 for r in results if r.get("execution_success", False))
    print(f"Correct: {correct}/{num_samples} ({100*correct/num_samples:.1f}%)")
    print(
        f"Execution Success: {executed}/{num_samples} ({100*executed/num_samples:.1f}%)"
    )
    # Save results
    output_file = "integration_test_results.json"
    with open(output_file, "w") as f:
        json.dump(results, f, indent=2)
    print(f"\nDetailed results saved to: {output_file}")
    return results
if __name__ == "__main__":
    import argparse

    # CLI wrapper: parse connection/model options and run the async test.
    parser = argparse.ArgumentParser(
        description="SQL Query Environment Integration Test"
    )
    parser.add_argument(
        "--base_url",
        type=str,
        default="http://localhost:8000/v1",
        help="Base URL for OpenAI-compatible API",
    )
    parser.add_argument(
        "--model",
        type=str,
        default="Qwen/Qwen3-8B",
        help="Model name",
    )
    parser.add_argument(
        "--api_key",
        type=str,
        default="x",
        help="API key",
    )
    parser.add_argument(
        "--num_samples",
        type=int,
        default=10,
        help="Number of samples to test",
    )
    args = parser.parse_args()
    asyncio.run(
        run_integration_test(
            base_url=args.base_url,
            model_name=args.model,
            api_key=args.api_key,
            num_samples=args.num_samples,
        )
    )

View file

@ -0,0 +1,224 @@
"""
Unit tests for SQL executor module.
"""
from sql_executor import (
create_table_from_wikisql,
execute_sql,
extract_boxed_sql,
normalize_result,
results_match,
)
class TestCreateTableFromWikiSQL:
    """Tests for create_table_from_wikisql function."""

    def test_basic_table_creation(self):
        """A two-row table round-trips through SELECT *."""
        conn = create_table_from_wikisql(
            ["name", "age", "city"],
            [["Alice", "30", "New York"], ["Bob", "25", "Los Angeles"]],
        )
        fetched = execute_sql(conn, "SELECT * FROM data")
        conn.close()
        assert fetched is not None
        assert len(fetched) == 2

    def test_table_with_special_column_names(self):
        """Columns containing spaces, dashes, and parentheses are accepted."""
        conn = create_table_from_wikisql(
            ["Player Name", "Goals-Scored", "Team (2024)"],
            [["Messi", "10", "Inter Miami"]],
        )
        # Quoting the identifiers makes the special names legal in SQLite
        fetched = execute_sql(conn, "SELECT * FROM data")
        conn.close()
        assert fetched is not None
        assert len(fetched) == 1

    def test_empty_table(self):
        """A table created with no rows yields an empty result set."""
        conn = create_table_from_wikisql(["col1", "col2"], [])
        fetched = execute_sql(conn, "SELECT * FROM data")
        conn.close()
        assert fetched == []
class TestExecuteSQL:
    """Tests for execute_sql function."""

    def test_valid_select(self):
        """A valid SELECT returns the matching rows."""
        conn = create_table_from_wikisql(
            ["name", "age"], [["Alice", "30"], ["Bob", "25"]]
        )
        fetched = execute_sql(conn, "SELECT name FROM data WHERE age = '30'")
        conn.close()
        assert fetched == [("Alice",)]

    def test_invalid_sql_returns_none(self):
        """Malformed SQL yields None instead of raising."""
        conn = create_table_from_wikisql(["a"], [["1"]])
        fetched = execute_sql(conn, "INVALID SQL SYNTAX")
        conn.close()
        assert fetched is None

    def test_aggregation(self):
        """COUNT(*) aggregates over the filtered rows."""
        conn = create_table_from_wikisql(["category"], [["A"], ["A"], ["B"]])
        fetched = execute_sql(conn, "SELECT COUNT(*) FROM data WHERE category = 'A'")
        conn.close()
        assert fetched == [(2,)]
class TestExtractBoxedSQL:
    """Tests for extract_boxed_sql function."""

    def test_simple_boxed(self):
        """SQL inside a plain \\boxed{} is returned verbatim."""
        extracted = extract_boxed_sql("The answer is \\boxed{SELECT * FROM data}")
        assert extracted == "SELECT * FROM data"

    def test_boxed_with_think_tags(self):
        """Extraction still works when <think> reasoning precedes the box."""
        text = """<think>
Let me think about this query...
</think>
\\boxed{SELECT name FROM data WHERE age > 25}"""
        assert extract_boxed_sql(text) == "SELECT name FROM data WHERE age > 25"

    def test_no_boxed_returns_none(self):
        """A bare SQL string without \\boxed{} yields None."""
        assert extract_boxed_sql("SELECT * FROM data") is None

    def test_empty_boxed(self):
        """An empty \\boxed{} yields None."""
        assert extract_boxed_sql("\\boxed{}") is None
class TestResultsMatch:
    """Tests for results_match function."""

    def test_identical_results(self):
        """Byte-identical result sets match."""
        rows = [("Alice", "30"), ("Bob", "25")]
        assert results_match(rows, list(rows)) is True

    def test_different_order(self):
        """Row order is irrelevant to matching."""
        assert results_match([("Alice",), ("Bob",)], [("Bob",), ("Alice",)]) is True

    def test_case_insensitive(self):
        """Matching ignores letter case."""
        assert results_match([("ALICE",)], [("alice",)]) is True

    def test_different_results(self):
        """Distinct values do not match."""
        assert results_match([("Alice",)], [("Bob",)]) is False

    def test_none_result(self):
        """A None (failed-query) result never matches anything."""
        assert results_match(None, [("a",)]) is False
        assert results_match([("a",)], None) is False
class TestNormalizeResult:
    """Tests for normalize_result function."""

    def test_normalization(self):
        """Values are lowercased and rows sorted; column order is preserved."""
        result = [("Alice", "30"), ("BOB", "25")]
        normalized = normalize_result(result)
        # Fix: the old assertion also accepted a column-swapped form
        # ([("25", "bob"), ("30", "alice")]) that normalize_result can
        # never produce — it only lowercases cells and sorts whole rows.
        # Assert the single valid output instead.
        assert normalized == [("alice", "30"), ("bob", "25")]
class TestIntegration:
    """Integration tests for the full scoring pipeline."""

    @staticmethod
    def _run_pair(gold_sql, generated_sql):
        """Execute both queries against the standard two-person table and
        return (generated rows, gold rows)."""
        conn = create_table_from_wikisql(
            ["name", "age"], [["Alice", "30"], ["Bob", "25"]]
        )
        gold_rows = execute_sql(conn, gold_sql)
        gen_rows = execute_sql(conn, generated_sql)
        conn.close()
        return gen_rows, gold_rows

    def test_correct_sql_scores_1(self):
        """An exact copy of the gold query matches its own results."""
        sql = "SELECT name FROM data WHERE age = '30'"
        gen_rows, gold_rows = self._run_pair(sql, sql)
        assert results_match(gen_rows, gold_rows) is True

    def test_semantically_equivalent_sql(self):
        """Different casing/spacing of the same query still matches."""
        gen_rows, gold_rows = self._run_pair(
            "SELECT name FROM data WHERE age = '30'",
            "select NAME from data where AGE='30'",
        )
        assert results_match(gen_rows, gold_rows) is True

    def test_wrong_sql_fails(self):
        """A query selecting a different person does not match."""
        gen_rows, gold_rows = self._run_pair(
            "SELECT name FROM data WHERE age = '30'",
            "SELECT name FROM data WHERE age = '25'",
        )
        assert results_match(gen_rows, gold_rows) is False

View file

@ -0,0 +1,227 @@
"""
WikiSQL Data Loader
Downloads and loads the WikiSQL dataset directly from the Salesforce GitHub repository.
This module provides functionality to fetch the dataset, extract it, and load it
in a format compatible with the SQL Query Environment.
"""
import json
import os
import tarfile
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple
from urllib.request import urlretrieve
# WikiSQL data source (archive in the official Salesforce repository)
WIKISQL_DATA_URL = "https://github.com/salesforce/WikiSQL/raw/master/data.tar.bz2"
# SQL operators from WikiSQL lib/query.py, indexed by the dataset's
# aggregation/operator indices. AGG_OPS[0] is the empty string: a plain
# column selection with no aggregate.
AGG_OPS = ["", "MAX", "MIN", "COUNT", "SUM", "AVG"]
# Condition operators; "OP" is WikiSQL's placeholder entry —
# NOTE(review): confirm how callers handle op_idx == 3 before relying on it.
COND_OPS = ["=", ">", "<", "OP"]
def get_cache_dir() -> Path:
    """Return (creating it if needed) the WikiSQL cache directory.

    Honors the WIKISQL_CACHE_DIR environment variable; defaults to
    ~/.cache/wikisql.
    """
    default_location = Path.home() / ".cache" / "wikisql"
    cache_dir = Path(os.environ.get("WIKISQL_CACHE_DIR", default_location))
    cache_dir.mkdir(parents=True, exist_ok=True)
    return cache_dir
def download_wikisql(cache_dir: Optional[Path] = None, force: bool = False) -> Path:
    """
    Download the WikiSQL dataset from GitHub.

    Args:
        cache_dir: Directory to cache the downloaded data (defaults to
            get_cache_dir())
        force: Force re-download and re-extraction even if already cached

    Returns:
        Path to the extracted data directory
    """
    if cache_dir is None:
        cache_dir = get_cache_dir()
    data_dir = cache_dir / "data"
    tar_file = cache_dir / "data.tar.bz2"
    # Check if already extracted: all six expected files must be present,
    # otherwise fall through and re-extract.
    if data_dir.exists() and not force:
        expected_files = [
            "train.jsonl",
            "dev.jsonl",
            "test.jsonl",
            "train.db",
            "dev.db",
            "test.db",
        ]
        if all((data_dir / f).exists() for f in expected_files):
            return data_dir
    # Download if needed (the tarball is kept so extraction can be retried)
    if not tar_file.exists() or force:
        print(f"Downloading WikiSQL dataset from {WIKISQL_DATA_URL}...")
        urlretrieve(WIKISQL_DATA_URL, tar_file)
        print(f"Downloaded to {tar_file}")
    # Extract
    # NOTE(review): extractall() trusts member paths inside the downloaded
    # archive; on Python 3.12+ consider tar.extractall(cache_dir,
    # filter="data") to guard against path-traversal entries.
    print(f"Extracting to {cache_dir}...")
    with tarfile.open(tar_file, "r:bz2") as tar:
        tar.extractall(cache_dir)
    print("Extraction complete.")
    return data_dir
def build_human_readable_sql(
    table_header: List[str],
    sel: int,
    agg: int,
    conds: List[Tuple[int, int, str]],
) -> str:
    """
    Build a human-readable SQL query from WikiSQL's structured format.

    Args:
        table_header: List of column names
        sel: Index of selected column
        agg: Index into AGG_OPS (0 means no aggregation)
        conds: List of (column_index, operator_index, condition) tuples

    Returns:
        SQL query string of the form
        ``SELECT [AGG(]col[)] FROM data [WHERE ...]``
    """
    # Build SELECT clause; fall back to a synthetic name for out-of-range
    # column indices so malformed examples still produce a query string.
    sel_col = table_header[sel] if sel < len(table_header) else f"col{sel}"
    if agg > 0:
        select_clause = f"{AGG_OPS[agg]}({sel_col})"
    else:
        select_clause = sel_col
    sql = f"SELECT {select_clause} FROM data"
    # Build WHERE clause
    if conds:
        where_parts = []
        for col_idx, op_idx, condition in conds:
            col_name = (
                table_header[col_idx]
                if col_idx < len(table_header)
                else f"col{col_idx}"
            )
            op = COND_OPS[op_idx] if op_idx < len(COND_OPS) else "="
            if isinstance(condition, str):
                # Escape embedded single quotes by doubling them (standard
                # SQL), so values like "O'Brien" yield valid, executable SQL.
                escaped = condition.replace("'", "''")
                where_parts.append(f"{col_name} {op} '{escaped}'")
            else:
                where_parts.append(f"{col_name} {op} {condition}")
        sql += " WHERE " + " AND ".join(where_parts)
    return sql
def load_wikisql_split(
    split: str = "train",
    cache_dir: Optional[Path] = None,
) -> List[Dict[str, Any]]:
    """
    Load a single WikiSQL split as a list of example dicts.

    Args:
        split: One of 'train', 'dev', or 'test'
        cache_dir: Directory containing the WikiSQL data; the dataset is
            downloaded on demand when omitted

    Returns:
        List of dictionaries with question, table info, and gold SQL

    Raises:
        FileNotFoundError: If the split's JSONL or tables file is missing
    """
    data_dir = download_wikisql() if cache_dir is None else cache_dir
    jsonl_file = data_dir / f"{split}.jsonl"
    tables_file = data_dir / f"{split}.tables.jsonl"
    if not jsonl_file.exists():
        raise FileNotFoundError(f"WikiSQL {split} JSONL file not found: {jsonl_file}")
    if not tables_file.exists():
        raise FileNotFoundError(f"WikiSQL {split} tables file not found: {tables_file}")
    # Index tables by id; tables.jsonl carries the real column headers.
    tables_cache: Dict[str, Dict[str, Any]] = {}
    with open(tables_file, "r", encoding="utf-8") as fh:
        for raw in fh:
            tbl = json.loads(raw.strip())
            tables_cache[tbl["id"]] = {
                "header": tbl["header"],
                "rows": tbl["rows"],
                # Older dumps may lack "types"; default every column to text.
                "types": tbl.get("types", ["text"] * len(tbl["header"])),
            }
    print(f"Loaded {len(tables_cache)} tables for {split} split")
    # Pair each question with its table and a human-readable gold query.
    dataset: List[Dict[str, Any]] = []
    skipped = 0
    with open(jsonl_file, "r", encoding="utf-8") as fh:
        for raw in fh:
            item = json.loads(raw.strip())
            table_info = tables_cache.get(item["table_id"])
            if not table_info or not table_info["header"]:
                skipped += 1
                continue
            sql_spec = item["sql"]
            gold_sql = build_human_readable_sql(
                table_info["header"],
                sql_spec["sel"],
                sql_spec["agg"],
                [(c[0], c[1], c[2]) for c in sql_spec["conds"]],
            )
            dataset.append(
                {
                    "question": item["question"],
                    "header": table_info["header"],
                    "rows": table_info["rows"],
                    "types": table_info["types"],
                    "gold_sql": gold_sql,
                    "table_id": item["table_id"],
                }
            )
    if skipped:
        print(f"Skipped {skipped} examples with missing tables")
    return dataset
def load_wikisql(cache_dir: Optional[Path] = None) -> Dict[str, List[Dict[str, Any]]]:
    """
    Load every split of the WikiSQL dataset.

    Args:
        cache_dir: Directory containing the WikiSQL data; downloaded on
            demand when omitted

    Returns:
        Dictionary mapping 'train', 'dev', and 'test' to their example lists
    """
    split_names = ("train", "dev", "test")
    return {name: load_wikisql_split(name, cache_dir) for name in split_names}
if __name__ == "__main__":
# Test the loader
print("Testing WikiSQL loader...")
train = load_wikisql_split("train")
print(f"Loaded {len(train)} training examples")
if train:
print("\nFirst example:")
print(f" Question: {train[0]['question']}")
print(f" Header: {train[0]['header']}")
print(f" Gold SQL: {train[0]['gold_sql']}")
print(f" Rows (first 2): {train[0]['rows'][:2]}")