mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-19 12:57:58 +00:00
Add SQL Query Generation Environment
This commit is contained in:
parent
6d79d4c7ad
commit
c927794248
6 changed files with 1427 additions and 0 deletions
115
environments/community/sql_query_env/README.md
Normal file
115
environments/community/sql_query_env/README.md
Normal file
|
|
@ -0,0 +1,115 @@
|
|||
# SQL Query Generation Environment
|
||||
|
||||
Train LLMs to generate correct SQL queries from natural language questions.
|
||||
|
||||
## Overview
|
||||
|
||||
This environment uses the [Salesforce/WikiSQL](https://huggingface.co/datasets/Salesforce/wikisql) dataset to train language models on text-to-SQL tasks. Queries are verified by **executing** the generated SQL against in-memory SQLite databases and comparing results to ground truth.
|
||||
|
||||
## Dataset
|
||||
|
||||
- **Source**: [Salesforce/WikiSQL](https://huggingface.co/datasets/Salesforce/wikisql)
|
||||
- **Size**: 80,654 examples (train + validation + test)
|
||||
- **Format**: Natural language questions with table schemas and ground truth SQL
|
||||
|
||||
## Usage
|
||||
|
||||
### Training Mode (with API Server)
|
||||
|
||||
```bash
|
||||
# Terminal 1: Start the Atropos API
|
||||
run-api
|
||||
|
||||
# Terminal 2: Run the environment
|
||||
python sql_query_env.py serve --slurm False
|
||||
```
|
||||
|
||||
### Local Testing (without API)
|
||||
|
||||
```bash
|
||||
python sql_query_env.py process --env.data_path_to_save_groups sql_output.jsonl
|
||||
```
|
||||
|
||||
This generates `sql_output.jsonl` and `sql_output.html` for inspection.
|
||||
|
||||
### With Local vLLM Server
|
||||
|
||||
```bash
|
||||
python sql_query_env.py process \
|
||||
--env.data_path_to_save_groups sql_output.jsonl \
|
||||
--openai.base_url http://localhost:9001/v1 \
|
||||
--openai.model_name YOUR_MODEL_NAME
|
||||
```
|
||||
|
||||
## Reward Function
|
||||
|
||||
| Score | Condition |
|
||||
|-------|-----------|
|
||||
| **1.0** | Generated SQL executes and returns same result as gold SQL |
|
||||
| **-1.0** | SQL fails to execute or returns incorrect result |
|
||||
|
||||
When all responses in a group are correct, a length penalty is applied to encourage concise solutions.
|
||||
|
||||
## Prompt Format
|
||||
|
||||
The model receives a table schema and question:
|
||||
|
||||
```
|
||||
Table: data
|
||||
Columns: col1, col2, col3
|
||||
Sample data:
|
||||
value1 | value2 | value3
|
||||
|
||||
Question: What is the value of col1 where col2 equals X?
|
||||
```
|
||||
|
||||
Output should be in boxed format:
|
||||
```
|
||||
<think>
|
||||
[Chain of thought reasoning]
|
||||
</think>
|
||||
|
||||
\boxed{SELECT col1 FROM data WHERE col2 = 'X'}
|
||||
```
|
||||
|
||||
## Unit Tests
|
||||
|
||||
```bash
|
||||
# Run unit tests
|
||||
python -m pytest test_sql_executor.py -v
|
||||
```
|
||||
|
||||
All 19 tests cover:
|
||||
- Table creation with special column names
|
||||
- SQL execution and error handling
|
||||
- `\boxed{}` extraction patterns
|
||||
- Result comparison and normalization
|
||||
- End-to-end scoring integration
|
||||
|
||||
## LLM Integration Test
|
||||
|
||||
The environment has been verified with Qwen3-8B on an NVIDIA H200:
|
||||
|
||||
```bash
|
||||
# Run integration test with a local vLLM server
|
||||
python test_integration.py --base_url http://localhost:8000/v1 --model Qwen/Qwen3-8B
|
||||
```
|
||||
|
||||
Test results:
|
||||
- **40% accuracy** on 10 random WikiSQL examples
|
||||
- SQL extraction from `\boxed{}` working correctly
|
||||
- Execution-based scoring producing correct reward signals
|
||||
|
||||
## Files
|
||||
|
||||
| File | Description |
|
||||
|------|-------------|
|
||||
| `sql_query_env.py` | Main environment implementation |
|
||||
| `sql_executor.py` | SQLite execution and scoring utilities |
|
||||
| `wikisql_loader.py` | WikiSQL dataset loader (from GitHub) |
|
||||
| `test_sql_executor.py` | Unit tests (19 tests) |
|
||||
| `test_integration.py` | LLM integration test |
|
||||
|
||||
## Author
|
||||
|
||||
Community contribution to Atropos.
|
||||
166
environments/community/sql_query_env/sql_executor.py
Normal file
166
environments/community/sql_query_env/sql_executor.py
Normal file
|
|
@ -0,0 +1,166 @@
|
|||
"""
|
||||
SQL Executor Module
|
||||
|
||||
Provides safe SQL execution against in-memory SQLite databases.
|
||||
Used by the SQL Query Environment for reward verification.
|
||||
"""
|
||||
|
||||
import re
|
||||
import sqlite3
|
||||
from typing import Any, List, Optional, Tuple
|
||||
|
||||
|
||||
def create_table_from_wikisql(
    header: List[str],
    rows: List[List[Any]],
    types: Optional[List[str]] = None,
) -> sqlite3.Connection:
    """Build an in-memory SQLite database from a WikiSQL table.

    Args:
        header: Column names (may contain spaces or special characters).
        rows: Row data; every cell is stored as TEXT.
        types: Optional WikiSQL column types (currently unused).

    Returns:
        An open SQLite connection whose ``data`` table holds the rows.
    """
    conn = sqlite3.connect(":memory:")
    cur = conn.cursor()

    # Double-quoted identifiers let SQLite accept arbitrary column names.
    column_defs = ", ".join(f'"{name}" TEXT' for name in header)
    cur.execute(f"CREATE TABLE data ({column_defs})")

    # Some WikiSQL queries reference the table as "table"; expose a view
    # with that name so such queries still execute.
    try:
        cur.execute("CREATE VIEW 'table' AS SELECT * FROM data")
    except sqlite3.OperationalError:
        pass  # name collision -- the alias is best-effort only

    if rows:
        slots = ", ".join(["?"] * len(header))
        # Store everything as a string; None becomes the empty string.
        as_text = [
            ["" if cell is None else str(cell) for cell in row] for row in rows
        ]
        cur.executemany(f"INSERT INTO data VALUES ({slots})", as_text)

    conn.commit()
    return conn
|
||||
|
||||
|
||||
def quote_identifiers_in_sql(sql: str, header: List[str]) -> str:
    """Wrap non-trivial column names from *header* in double quotes.

    WikiSQL's human-readable SQL leaves columns such as
    ``State/territory`` unquoted; SQLite needs them quoted to parse.
    Columns that are already plain identifiers are left untouched.
    """
    quoted = sql
    # Longest names first so a shorter column that is a substring of a
    # longer one never clobbers part of it.
    for name in sorted(header, key=len, reverse=True):
        if re.match(r"^[a-zA-Z_][a-zA-Z0-9_]*$", name):
            continue  # plain identifier: needs no quoting

        escaped = re.escape(name)
        # Only rewrite occurrences that are not already surrounded by
        # double quotes.
        quoted = re.sub(rf'(?<!")\b{escaped}\b(?!")', f'"{name}"', quoted)

    return quoted
|
||||
|
||||
|
||||
def execute_sql(conn: sqlite3.Connection, sql: str) -> Optional[List[Tuple]]:
    """Execute *sql* on *conn*, swallowing any failure.

    Args:
        conn: An open SQLite connection.
        sql: SQL query to execute.

    Returns:
        All fetched rows, or ``None`` when execution raises.  Callers
        treat ``None`` as "query failed" when scoring.
    """
    # Fix: the original had two handlers -- ``except sqlite3.Error`` then
    # ``except Exception`` -- but sqlite3.Error subclasses Exception, so the
    # pair was redundant.  One broad handler has identical behavior.
    try:
        cursor = conn.cursor()
        cursor.execute(sql)
        return cursor.fetchall()
    except Exception:
        return None
|
||||
|
||||
|
||||
def extract_boxed_sql(text: str) -> Optional[str]:
    """Pull the SQL out of a ``\\boxed{...}`` marker in *text*.

    Args:
        text: Raw LLM response text.

    Returns:
        The stripped SQL string, or ``None`` when no non-empty boxed
        content is present.
    """
    # First pattern tolerates one level of nested braces (e.g. literals
    # containing ``{...}``); the second is a simple non-greedy fallback.
    candidates = (
        r"\\boxed\{([^{}]*(?:\{[^{}]*\}[^{}]*)*)\}",
        r"\\boxed\{(.+?)\}",
    )

    for candidate in candidates:
        found = re.search(candidate, text, re.DOTALL)
        if found:
            content = found.group(1).strip()
            if content:
                return content

    return None
|
||||
|
||||
|
||||
def normalize_result(result: List[Tuple]) -> List[Tuple]:
    """Canonicalize rows for order- and case-insensitive comparison.

    Every cell becomes a lowercased, stripped string (``None`` maps to
    the empty string) and the rows are returned in sorted order.
    """
    return sorted(
        tuple("" if cell is None else str(cell).lower().strip() for cell in row)
        for row in result
    )
|
||||
|
||||
|
||||
def results_match(result1: List[Tuple], result2: List[Tuple]) -> bool:
    """Decide whether two SQL result sets are equivalent.

    Comparison is order-independent and case-insensitive; a missing
    (``None``) result set never matches anything.

    Args:
        result1: First result set.
        result2: Second result set.

    Returns:
        True when the normalized result sets are equal.
    """
    if result1 is None or result2 is None:
        return False
    return normalize_result(result1) == normalize_result(result2)
|
||||
419
environments/community/sql_query_env/sql_query_env.py
Normal file
419
environments/community/sql_query_env/sql_query_env.py
Normal file
|
|
@ -0,0 +1,419 @@
|
|||
"""
|
||||
SQL Query Generation Environment for Atropos
|
||||
|
||||
Trains LLMs to generate correct SQL queries from natural language questions.
|
||||
Uses the Salesforce/WikiSQL dataset with execution-based scoring.
|
||||
"""
|
||||
|
||||
import random
from typing import Any, Dict, List, Optional, Tuple, TypedDict, Union

from tqdm.asyncio import tqdm_asyncio

from atroposlib.envs.base import (
    APIServerConfig,
    BaseEnv,
    BaseEnvConfig,
    ScoredDataGroup,
)
from atroposlib.type_definitions import Item

from sql_executor import (
    create_table_from_wikisql,
    execute_sql,
    extract_boxed_sql,
    quote_identifiers_in_sql,
    results_match,
)
from wikisql_loader import load_wikisql_split
|
||||
|
||||
# System prompt following the established Atropos pattern.
# Part 1: the generic "deep thinking" preamble shared across Atropos
# environments; the model reasons inside <think></think> tags.
system_prompt = (
    "You are a deep thinking AI, you may use extremely long chains of thought "
    "to deeply consider the problem and deliberate with yourself via systematic "
    "reasoning processes to help come to a correct solution prior to answering. "
    "You should enclose your thoughts and internal monologue inside <think> </think> "
    "tags, and then provide your solution or response to the problem.\n\n"
)

# Part 2: the task-specific instructions; the final answer must be wrapped
# in \boxed{} so extract_boxed_sql can recover it for scoring.
system_prompt += """You are a SQL expert. Given a table schema and a natural language question,
generate a SQL query that answers the question.

You are allocated a maximum of 1024 tokens, please strive to use less.

Provide your SQL query inside \\boxed{} like this: \\boxed{SELECT column FROM table WHERE condition}

Important:
- Use only the columns provided in the table schema
- The table is always named "data"
- Ensure your SQL syntax is valid SQLite

So please end your answer with \\boxed{your SQL query here}"""
|
||||
|
||||
|
||||
class WikiSQLRow(TypedDict):
    """Type definition for a WikiSQL dataset row."""

    # Natural-language question to be answered with SQL.
    question: str
    # Column names of the table.
    header: List[str]
    # Row data; each inner list is one table row.
    rows: List[List[Any]]
    # WikiSQL column types (e.g. "text", "real") -- not used for scoring.
    types: List[str]
    # Ground-truth SQL used for execution-based reward verification.
    gold_sql: str
||||
|
||||
|
||||
def format_table_schema(
    header: List[str], rows: List[List[Any]], max_rows: int = 3
) -> str:
    """Render the table name, column list, and a few sample rows.

    Args:
        header: Column names.
        rows: Row data; only the first *max_rows* are shown.
        max_rows: Cap on sample rows included in the prompt.

    Returns:
        A newline-terminated schema description for the user prompt.
    """
    parts = ["Table: data", f"Columns: {', '.join(header)}"]
    if rows:
        parts.append("Sample data:")
        parts.extend(
            "  " + " | ".join(str(cell) for cell in row) for row in rows[:max_rows]
        )
    return "\n".join(parts) + "\n"
|
||||
|
||||
|
||||
class SQLQueryEnv(BaseEnv):
|
||||
"""
|
||||
Environment for training LLMs to generate SQL queries.
|
||||
|
||||
Uses the Salesforce/WikiSQL dataset and verifies correctness
|
||||
by executing SQL against in-memory SQLite databases.
|
||||
"""
|
||||
|
||||
name = "sql_query"
|
||||
|
||||
    def __init__(
        self,
        config: BaseEnvConfig,
        server_configs: List[APIServerConfig],
        slurm=True,
        testing=False,
    ):
        """Initialize the environment and its rolling metric buffers.

        Args:
            config: Environment configuration (group size, token limits, ...).
            server_configs: Inference server endpoints to sample from.
            slurm: Whether the environment runs under SLURM scheduling.
            testing: Whether to run in testing mode.
        """
        super().__init__(config, server_configs, slurm, testing)
        # Fraction-correct per rollout since the last wandb_log flush.
        self.percent_correct_buffer = list()
        # (name, value) pairs queued by evaluate() for the next log.
        self.eval_metrics = list()
        # 1/0 flags: did the generated SQL execute at all (even if wrong)?
        self.execution_success_buffer = list()
|
||||
|
||||
    @classmethod
    def config_init(cls) -> Tuple[BaseEnvConfig, List[APIServerConfig]]:
        """Initialize default configuration for the environment.

        Returns:
            Tuple of (environment config, list of inference server configs).
        """
        env_config = BaseEnvConfig(
            tokenizer_name="NousResearch/DeepHermes-3-Llama-3-3B-Preview",
            group_size=8,  # completions sampled per item
            use_wandb=True,
            rollout_server_url="http://localhost:8000",
            total_steps=1000,
            batch_size=12,
            steps_per_eval=100,
            max_token_length=1024,  # also advertised in the system prompt
            wandb_name="sql_query",
        )
        server_configs = [
            APIServerConfig(
                model_name="NousResearch/DeepHermes-3-Llama-3-3B-Preview",
                base_url="http://localhost:9001/v1",
                api_key="x",  # placeholder; local servers ignore the key
                num_requests_for_eval=256,
            ),
        ]
        return env_config, server_configs
|
||||
|
||||
async def wandb_log(self, wandb_metrics: Optional[Dict] = None):
|
||||
"""Log custom metrics to WandB."""
|
||||
if wandb_metrics is None:
|
||||
wandb_metrics = {}
|
||||
|
||||
# Log percent correct
|
||||
try:
|
||||
wandb_metrics["train/percent_correct"] = sum(
|
||||
self.percent_correct_buffer
|
||||
) / len(self.percent_correct_buffer)
|
||||
except ZeroDivisionError:
|
||||
pass
|
||||
|
||||
# Log execution success rate
|
||||
try:
|
||||
wandb_metrics["train/execution_success"] = sum(
|
||||
self.execution_success_buffer
|
||||
) / len(self.execution_success_buffer)
|
||||
except ZeroDivisionError:
|
||||
pass
|
||||
|
||||
self.percent_correct_buffer = list()
|
||||
self.execution_success_buffer = list()
|
||||
|
||||
for item in self.eval_metrics:
|
||||
wandb_metrics[item[0]] = item[1]
|
||||
self.eval_metrics = list()
|
||||
|
||||
await super().wandb_log(wandb_metrics)
|
||||
|
||||
    async def setup(self):
        """Load the WikiSQL dataset and prepare train/test splits."""
        # Load WikiSQL dataset directly from GitHub source (see
        # wikisql_loader); the HF dataset script was deprecated.
        print("Loading WikiSQL training data...")
        self.train = load_wikisql_split("train")
        print(f"Loaded {len(self.train)} training examples")

        print("Loading WikiSQL test data...")
        self.test = load_wikisql_split("test")
        print(f"Loaded {len(self.test)} test examples")

        # Shuffle once so the sequential cursor in get_next_item walks the
        # data in random order.
        random.shuffle(self.train)
        self.iter = 0
|
||||
|
||||
def save_checkpoint(self, step, data=None):
|
||||
"""Save checkpoint with iteration state."""
|
||||
if data is None:
|
||||
data = {}
|
||||
data["iter"] = self.iter
|
||||
super().save_checkpoint(step, data)
|
||||
|
||||
def _score_sql(
|
||||
self,
|
||||
generated_sql: str,
|
||||
gold_sql: str,
|
||||
header: List[str],
|
||||
rows: List[List[Any]],
|
||||
) -> Tuple[float, bool]:
|
||||
"""
|
||||
Score SQL by execution comparison.
|
||||
|
||||
Returns:
|
||||
Tuple of (score, execution_success)
|
||||
"""
|
||||
if not generated_sql:
|
||||
return -1.0, False
|
||||
|
||||
# Create in-memory table
|
||||
try:
|
||||
conn = create_table_from_wikisql(header, rows)
|
||||
except Exception:
|
||||
return -1.0, False
|
||||
|
||||
# Execute gold SQL
|
||||
gold_result = execute_sql(conn, gold_sql)
|
||||
if gold_result is None:
|
||||
conn.close()
|
||||
return -1.0, False
|
||||
|
||||
# Execute generated SQL
|
||||
gen_result = execute_sql(conn, generated_sql)
|
||||
conn.close()
|
||||
|
||||
if gen_result is None:
|
||||
return -1.0, False
|
||||
|
||||
# Compare results
|
||||
if results_match(gen_result, gold_result):
|
||||
return 1.0, True
|
||||
else:
|
||||
return -1.0, True
|
||||
|
||||
    async def rollout_and_score_eval(
        self, question: str, gold_sql: str, header: List[str], rows: List[List[Any]]
    ) -> dict:
        """Rollout and score a single evaluation item.

        Args:
            question: Natural-language question.
            gold_sql: Ground-truth SQL for the question.
            header: Table column names.
            rows: Table row data.

        Returns:
            Dict with a binary "score" (1 correct / 0 otherwise) and the
            full "sample" record for eval logging.
        """
        table_schema = format_table_schema(header, rows)
        user_content = f"{table_schema}\nQuestion: {question}"

        # One sample per eval item at the eval temperature (0.6).
        async with self.server.managed_server(tokenizer=self.tokenizer) as managed:
            completion = await managed.chat_completion(
                messages=[
                    {"role": "system", "content": system_prompt},
                    {"role": "user", "content": user_content},
                ],
                n=1,
                max_tokens=self.config.max_token_length,
                temperature=0.6,
            )
            response_content = completion.choices[0].message.content

        # Extract and score generated SQL by execution comparison.
        generated_sql = extract_boxed_sql(response_content)
        score, exec_success = self._score_sql(generated_sql, gold_sql, header, rows)
        correct = score == 1.0

        sample = {
            "messages": [
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_content},
                {"role": "assistant", "content": response_content},
            ],
            "question": question,
            "gold_sql": gold_sql,
            "generated_sql": generated_sql,
            "score": 1 if correct else 0,
            "correct": correct,
            "execution_success": exec_success,
            "finish_reason": completion.choices[0].finish_reason,
        }

        return {"score": 1 if correct else 0, "sample": sample}
|
||||
|
||||
    async def evaluate(self, *args, **kwargs):
        """Run evaluation on the test set and log aggregate metrics."""
        import time

        start_time = time.time()

        eval_tasks = []
        # Limit eval to first 500 items for efficiency
        for item in self.test[:500]:
            eval_tasks.append(
                self.rollout_and_score_eval(
                    item["question"], item["gold_sql"], item["header"], item["rows"]
                )
            )
        # Run all eval rollouts concurrently with a progress bar.
        results = await tqdm_asyncio.gather(*eval_tasks)

        scores = [result["score"] for result in results]
        samples = [result["sample"] for result in results]

        percent_correct = sum(scores) / len(scores) if scores else 0

        end_time = time.time()

        # Queued for the next wandb_log flush.
        self.eval_metrics.append(("eval/percent_correct", percent_correct))

        eval_metrics = {
            "eval/percent_correct": percent_correct,
        }

        await self.evaluate_log(
            metrics=eval_metrics,
            samples=samples,
            start_time=start_time,
            end_time=end_time,
            generation_parameters={
                "temperature": 0.6,
                "max_tokens": self.config.max_token_length,
            },
        )
|
||||
|
||||
    async def collect_trajectories(
        self, item: WikiSQLRow
    ) -> Tuple[ScoredDataGroup, list[Item]]:
        """Generate a group of SQL completions for one item and score them.

        Args:
            item: WikiSQL row with question, table data, and gold SQL.

        Returns:
            Tuple of (scored group, backlog items); the backlog is always
            empty for this environment.
        """
        table_schema = format_table_schema(item["header"], item["rows"])
        user_content = f"{table_schema}\nQuestion: {item['question']}"
        user_message = {"role": "user", "content": user_content}

        # Sample group_size completions at temperature 1.0 for diversity.
        async with self.server.managed_server(tokenizer=self.tokenizer) as managed:
            chat_completions = await managed.chat_completion(
                messages=[{"role": "system", "content": system_prompt}, user_message],
                n=self.config.group_size,
                max_tokens=self.config.max_token_length,
                temperature=1.0,
            )

            # Per-completion token/mask/logprob records from the server.
            state = managed.get_state()
            nodes = state["nodes"]

        to_score = list()
        to_backlog = list()

        for i, chat_completion in enumerate(chat_completions.choices):
            messages = (
                {"role": "system", "content": system_prompt},
                user_message,
                {"role": "assistant", "content": chat_completion.message.content},
            )
            to_score.append(
                {
                    "messages": messages,
                    "gold_sql": item["gold_sql"],
                    "header": item["header"],
                    "rows": item["rows"],
                    "finish_reason": chat_completion.finish_reason,
                    "tokens": nodes[i].tokens,
                    "masks": nodes[i].masked_tokens,
                    "logprobs": nodes[i].logprobs,
                }
            )

        to_postprocess = await self.score(to_score)
        return to_postprocess, to_backlog
|
||||
|
||||
async def score(
|
||||
self, rollout_group_data
|
||||
) -> Union[Optional[ScoredDataGroup], List[Optional[ScoredDataGroup]]]:
|
||||
"""Score generated SQL queries by execution comparison."""
|
||||
scores = ScoredDataGroup()
|
||||
scores["tokens"] = list()
|
||||
scores["masks"] = list()
|
||||
scores["scores"] = list()
|
||||
|
||||
# Get table info from first item
|
||||
gold_sql = rollout_group_data[0]["gold_sql"]
|
||||
header = rollout_group_data[0]["header"]
|
||||
rows = rollout_group_data[0]["rows"]
|
||||
|
||||
random.shuffle(rollout_group_data)
|
||||
|
||||
for item in rollout_group_data:
|
||||
response_content = item["messages"][-1]["content"]
|
||||
|
||||
# Extract SQL from response
|
||||
generated_sql = extract_boxed_sql(response_content)
|
||||
|
||||
# Score by execution
|
||||
reward, exec_success = self._score_sql(
|
||||
generated_sql, gold_sql, header, rows
|
||||
)
|
||||
self.execution_success_buffer.append(1 if exec_success else 0)
|
||||
|
||||
tokens = item["tokens"]
|
||||
masks = item["masks"]
|
||||
logprobs = item["logprobs"]
|
||||
|
||||
# Remove obviously bad examples
|
||||
if len([1 for i in masks if i != -100]) < 10:
|
||||
continue
|
||||
|
||||
scores["tokens"].append(tokens)
|
||||
scores["masks"].append(masks)
|
||||
scores["inference_logprobs"].append(logprobs)
|
||||
scores["scores"].append(reward)
|
||||
|
||||
if len(scores["tokens"]) >= self.config.group_size:
|
||||
break
|
||||
|
||||
for score in scores["scores"]:
|
||||
self.percent_correct_buffer.append(max(score, 0))
|
||||
|
||||
# Check if all scores are the same
|
||||
if all([score == 1 for score in scores["scores"]]):
|
||||
# Apply length penalty when all are correct
|
||||
token_lengths = [len(token) for token in scores["tokens"]]
|
||||
if max(token_lengths) == 0:
|
||||
return None
|
||||
|
||||
max_allowed_length = self.config.max_token_length
|
||||
length_threshold = max_allowed_length * 0.5
|
||||
|
||||
scores["scores"] = []
|
||||
for length in token_lengths:
|
||||
if length <= length_threshold:
|
||||
scores["scores"].append(1.0)
|
||||
else:
|
||||
percentage_of_range = (length - length_threshold) / (
|
||||
max_allowed_length - length_threshold
|
||||
)
|
||||
percentage_of_range = min(percentage_of_range, 1.0)
|
||||
scores["scores"].append(1.0 - percentage_of_range)
|
||||
|
||||
if all([scores["scores"][0] == score for score in scores["scores"]]):
|
||||
return None # If all the same, return None
|
||||
|
||||
return scores
|
||||
|
||||
async def get_next_item(self) -> WikiSQLRow:
|
||||
"""Get the next training item."""
|
||||
next_item = self.train[self.iter % len(self.train)]
|
||||
self.iter += 1
|
||||
return next_item
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Atropos CLI entry point: provides the `serve` / `process` subcommands.
    SQLQueryEnv.cli()
|
||||
276
environments/community/sql_query_env/test_integration.py
Normal file
276
environments/community/sql_query_env/test_integration.py
Normal file
|
|
@ -0,0 +1,276 @@
|
|||
#!/usr/bin/env python3
|
||||
"""
|
||||
Integration test for SQL Query Environment that works with OpenAI-compatible APIs.
|
||||
|
||||
This test verifies:
|
||||
1. WikiSQL dataset loading
|
||||
2. SQL generation from LLM
|
||||
3. SQL extraction from \boxed{}
|
||||
4. SQL execution and result comparison
|
||||
5. Scoring logic
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from typing import Any, List
|
||||
|
||||
import openai
|
||||
|
||||
# Import local modules
|
||||
from sql_executor import (
|
||||
create_table_from_wikisql,
|
||||
execute_sql,
|
||||
extract_boxed_sql,
|
||||
quote_identifiers_in_sql,
|
||||
results_match,
|
||||
)
|
||||
from wikisql_loader import load_wikisql_split
|
||||
|
||||
# System prompt from the environment.
# NOTE(review): duplicated by hand from sql_query_env.system_prompt; if the
# two drift apart the integration test exercises a different prompt than
# training does -- consider importing it instead.
SYSTEM_PROMPT = (
    "You are a deep thinking AI, you may use extremely long chains of thought "
    "to deeply consider the problem and deliberate with yourself via systematic "
    "reasoning processes to help come to a correct solution prior to answering. "
    "You should enclose your thoughts and internal monologue inside <think> </think> "
    "tags, and then provide your solution or response to the problem.\n\n"
    "You are a SQL expert. Given a table schema and a natural language question, "
    "generate a SQL query that answers the question.\n\n"
    "You are allocated a maximum of 1024 tokens, please strive to use less.\n\n"
    "Provide your SQL query inside \\boxed{} like this: "
    "\\boxed{SELECT column FROM table WHERE condition}\n\n"
    "Important:\n"
    "- Use only the columns provided in the table schema\n"
    '- The table is always named "data"\n'
    "- Ensure your SQL syntax is valid SQLite\n\n"
    "So please end your answer with \\boxed{your SQL query here}"
)
|
||||
|
||||
|
||||
def format_table_schema(
    header: List[str], rows: List[List[Any]], max_rows: int = 3
) -> str:
    """Render the table name, column list, and a few sample rows.

    Args:
        header: Column names.
        rows: Row data; only the first *max_rows* are shown.
        max_rows: Cap on sample rows included in the prompt.

    Returns:
        A newline-terminated schema description for the user prompt.
    """
    parts = ["Table: data", f"Columns: {', '.join(header)}"]
    if rows:
        parts.append("Sample data:")
        parts.extend(
            "  " + " | ".join(str(cell) for cell in row) for row in rows[:max_rows]
        )
    return "\n".join(parts) + "\n"
|
||||
|
||||
|
||||
def score_sql(
    generated_sql: str,
    gold_sql: str,
    header: List[str],
    rows: List[List[Any]],
) -> dict:
    """Score SQL by execution comparison.

    Mirrors SQLQueryEnv._score_sql but returns a rich diagnostic dict
    instead of a bare (score, success) tuple, for human inspection.

    Returns:
        Dict with score (1.0 match / -1.0 otherwise), execution_success,
        stringified result sets, the identifier-quoted gold SQL, and an
        error description when something failed.
    """
    result = {
        "generated_sql": generated_sql,
        "gold_sql": gold_sql,
        "score": -1.0,
        "execution_success": False,
        "gen_result": None,
        "gold_result": None,
        "error": None,
    }

    if not generated_sql:
        result["error"] = "No SQL extracted from response"
        return result

    # Create in-memory table
    try:
        conn = create_table_from_wikisql(header, rows)
    except Exception as e:
        result["error"] = f"Table creation failed: {e}"
        return result

    # Quote identifiers in gold SQL that need quoting (WikiSQL's
    # human-readable SQL leaves special-character columns unquoted).
    quoted_gold_sql = quote_identifiers_in_sql(gold_sql, header)
    result["quoted_gold_sql"] = quoted_gold_sql

    # Execute gold SQL
    gold_result = execute_sql(conn, quoted_gold_sql)
    if gold_result is None:
        conn.close()
        result["error"] = f"Gold SQL execution failed: {quoted_gold_sql}"
        return result
    result["gold_result"] = str(gold_result)

    # Execute generated SQL
    gen_result = execute_sql(conn, generated_sql)
    conn.close()

    if gen_result is None:
        result["error"] = "Generated SQL execution failed"
        return result

    result["gen_result"] = str(gen_result)
    result["execution_success"] = True

    # Compare results (order-independent, case-insensitive)
    if results_match(gen_result, gold_result):
        result["score"] = 1.0
    else:
        result["error"] = "Results do not match"

    return result
|
||||
|
||||
|
||||
async def test_single_item(client, model_name: str, item: dict, item_idx: int) -> dict:
    """Test a single WikiSQL item end to end.

    Sends the prompt to the model, extracts the boxed SQL, and scores it
    by execution against the gold SQL.

    Args:
        client: Async OpenAI-compatible client.
        model_name: Model identifier to query.
        item: WikiSQL row (question/header/rows/gold_sql).
        item_idx: Index of the item, echoed into the result for reporting.

    Returns:
        Result dict combining the (truncated) model response with the
        score_sql diagnostics; on request failure, an error record with
        score -1.0.
    """
    table_schema = format_table_schema(item["header"], item["rows"])
    user_content = f"{table_schema}\nQuestion: {item['question']}"

    try:
        response = await client.chat.completions.create(
            model=model_name,
            messages=[
                {"role": "system", "content": SYSTEM_PROMPT},
                {"role": "user", "content": user_content},
            ],
            max_tokens=1024,
            temperature=0.6,
        )

        response_content = response.choices[0].message.content

        # Extract SQL
        generated_sql = extract_boxed_sql(response_content)

        # Score
        score_result = score_sql(
            generated_sql, item["gold_sql"], item["header"], item["rows"]
        )

        return {
            "item_idx": item_idx,
            "question": item["question"],
            # Truncate long responses so the saved JSON stays readable.
            "response": (
                response_content[:500] + "..."
                if len(response_content) > 500
                else response_content
            ),
            **score_result,
        }

    except Exception as e:
        # API/network failures are recorded rather than raised so the run
        # continues with the remaining samples.
        return {
            "item_idx": item_idx,
            "question": item["question"],
            "error": str(e),
            "score": -1.0,
        }
|
||||
|
||||
|
||||
async def run_integration_test(
    base_url: str,
    model_name: str,
    api_key: str = "x",
    num_samples: int = 10,
):
    """Run the integration test against an OpenAI-compatible server.

    Args:
        base_url: API base URL (e.g. a local vLLM server).
        model_name: Model to query.
        api_key: API key; local servers typically ignore it.
        num_samples: Number of WikiSQL items to test.

    Returns:
        List of per-item result dicts (also written to
        integration_test_results.json).
    """
    print(f"\n{'='*60}")
    print("SQL Query Environment Integration Test")
    print(f"{'='*60}")
    print(f"Server: {base_url}")
    print(f"Model: {model_name}")
    print(f"Samples: {num_samples}")
    print()

    # Load dataset
    print("Loading WikiSQL training data...")
    train_data = load_wikisql_split("train")
    print(f"Loaded {len(train_data)} training examples")

    # Initialize OpenAI client
    client = openai.AsyncClient(
        base_url=base_url,
        api_key=api_key,
        timeout=120.0,
    )

    # Run tests sequentially so the per-item progress prints stay ordered.
    print(f"\nTesting {num_samples} samples...\n")
    results = []

    for i in range(min(num_samples, len(train_data))):
        item = train_data[i]
        print(f"[{i+1}/{num_samples}] Testing: {item['question'][:60]}...")
        result = await test_single_item(client, model_name, item, i)
        results.append(result)

        # Print result
        if result["score"] == 1.0:
            print(f"  ✓ CORRECT - Generated: {result.get('generated_sql', 'N/A')[:80]}")
        else:
            print(f"  ✗ INCORRECT - {result.get('error', 'Unknown error')}")
            if result.get("generated_sql"):
                print(f"    Generated: {result['generated_sql'][:80]}")
            print(f"    Gold: {result.get('gold_sql', 'N/A')[:80]}")

    # Summary
    print(f"\n{'='*60}")
    print("SUMMARY")
    print(f"{'='*60}")

    correct = sum(1 for r in results if r["score"] == 1.0)
    executed = sum(1 for r in results if r.get("execution_success", False))

    print(f"Correct: {correct}/{num_samples} ({100*correct/num_samples:.1f}%)")
    print(
        f"Execution Success: {executed}/{num_samples} ({100*executed/num_samples:.1f}%)"
    )

    # Save results
    output_file = "integration_test_results.json"
    with open(output_file, "w") as f:
        json.dump(results, f, indent=2)
    print(f"\nDetailed results saved to: {output_file}")

    return results
|
||||
|
||||
|
||||
if __name__ == "__main__":
    import argparse

    # CLI wrapper around run_integration_test.
    parser = argparse.ArgumentParser(
        description="SQL Query Environment Integration Test"
    )
    parser.add_argument(
        "--base_url",
        type=str,
        default="http://localhost:8000/v1",
        help="Base URL for OpenAI-compatible API",
    )
    parser.add_argument(
        "--model",
        type=str,
        default="Qwen/Qwen3-8B",
        help="Model name",
    )
    parser.add_argument(
        "--api_key",
        type=str,
        default="x",
        help="API key",
    )
    parser.add_argument(
        "--num_samples",
        type=int,
        default=10,
        help="Number of samples to test",
    )

    args = parser.parse_args()

    asyncio.run(
        run_integration_test(
            base_url=args.base_url,
            model_name=args.model,
            api_key=args.api_key,
            num_samples=args.num_samples,
        )
    )
|
||||
224
environments/community/sql_query_env/test_sql_executor.py
Normal file
224
environments/community/sql_query_env/test_sql_executor.py
Normal file
|
|
@ -0,0 +1,224 @@
|
|||
"""
|
||||
Unit tests for SQL executor module.
|
||||
"""
|
||||
|
||||
from sql_executor import (
|
||||
create_table_from_wikisql,
|
||||
execute_sql,
|
||||
extract_boxed_sql,
|
||||
normalize_result,
|
||||
results_match,
|
||||
)
|
||||
|
||||
|
||||
class TestCreateTableFromWikiSQL:
    """Unit tests covering create_table_from_wikisql."""

    @staticmethod
    def _select_all(header, rows):
        """Build an in-memory table, run SELECT *, always close the connection."""
        conn = create_table_from_wikisql(header, rows)
        try:
            return execute_sql(conn, "SELECT * FROM data")
        finally:
            conn.close()

    def test_basic_table_creation(self):
        """A plain two-row table round-trips through SELECT *."""
        fetched = self._select_all(
            ["name", "age", "city"],
            [["Alice", "30", "New York"], ["Bob", "25", "Los Angeles"]],
        )
        assert fetched is not None
        assert len(fetched) == 2

    def test_table_with_special_column_names(self):
        """Headers with spaces, dashes, and parentheses are sanitized, not rejected."""
        fetched = self._select_all(
            ["Player Name", "Goals-Scored", "Team (2024)"],
            [["Messi", "10", "Inter Miami"]],
        )
        assert fetched is not None
        assert len(fetched) == 1

    def test_empty_table(self):
        """A header with zero data rows yields an empty result set."""
        assert self._select_all(["col1", "col2"], []) == []
|
||||
|
||||
|
||||
class TestExecuteSQL:
    """Unit tests covering execute_sql."""

    @staticmethod
    def _query(header, rows, sql):
        """Create a table from (header, rows), run *sql*, and return the result."""
        conn = create_table_from_wikisql(header, rows)
        try:
            return execute_sql(conn, sql)
        finally:
            conn.close()

    def test_valid_select(self):
        """A filtered SELECT returns only the matching row."""
        out = self._query(
            ["name", "age"],
            [["Alice", "30"], ["Bob", "25"]],
            "SELECT name FROM data WHERE age = '30'",
        )
        assert out == [("Alice",)]

    def test_invalid_sql_returns_none(self):
        """Syntactically broken SQL yields None instead of raising."""
        assert self._query(["a"], [["1"]], "INVALID SQL SYNTAX") is None

    def test_aggregation(self):
        """COUNT(*) with a WHERE filter counts only matching rows."""
        out = self._query(
            ["category"],
            [["A"], ["A"], ["B"]],
            "SELECT COUNT(*) FROM data WHERE category = 'A'",
        )
        assert out == [(2,)]
|
||||
|
||||
|
||||
class TestExtractBoxedSQL:
    """Unit tests covering extract_boxed_sql."""

    def test_simple_boxed(self):
        """A bare \\boxed{...} wrapper yields its SQL payload."""
        assert (
            extract_boxed_sql("The answer is \\boxed{SELECT * FROM data}")
            == "SELECT * FROM data"
        )

    def test_boxed_with_think_tags(self):
        """A <think> preamble before the boxed answer is ignored."""
        reply = (
            "<think>\n"
            "Let me think about this query...\n"
            "</think>\n"
            "\n"
            "\\boxed{SELECT name FROM data WHERE age > 25}"
        )
        assert extract_boxed_sql(reply) == "SELECT name FROM data WHERE age > 25"

    def test_no_boxed_returns_none(self):
        """Raw SQL with no \\boxed{} wrapper is rejected."""
        assert extract_boxed_sql("SELECT * FROM data") is None

    def test_empty_boxed(self):
        """\\boxed{} with no content is rejected."""
        assert extract_boxed_sql("\\boxed{}") is None
|
||||
|
||||
|
||||
class TestResultsMatch:
    """Unit tests covering results_match."""

    def test_identical_results(self):
        """Byte-for-byte identical result sets match."""
        left = [("Alice", "30"), ("Bob", "25")]
        right = [("Alice", "30"), ("Bob", "25")]
        assert results_match(left, right) is True

    def test_different_order(self):
        """Row order is irrelevant to matching."""
        assert results_match([("Alice",), ("Bob",)], [("Bob",), ("Alice",)]) is True

    def test_case_insensitive(self):
        """String comparison ignores letter case."""
        assert results_match([("ALICE",)], [("alice",)]) is True

    def test_different_results(self):
        """Genuinely different values do not match."""
        assert results_match([("Alice",)], [("Bob",)]) is False

    def test_none_result(self):
        """A failed execution (None) never matches anything, on either side."""
        assert results_match(None, [("a",)]) is False
        assert results_match([("a",)], None) is False
|
||||
|
||||
|
||||
class TestNormalizeResult:
    """Tests for normalize_result function."""

    def test_normalization(self):
        """Test result normalization."""
        result = [("Alice", "30"), ("BOB", "25")]
        normalized = normalize_result(result)

        # Should be lowercase and sorted
        # NOTE(review): the two accepted alternatives disagree on column
        # order — the first has the values swapped within each row
        # (("25", "bob") instead of ("bob", "25")). Only one shape can
        # actually come out of normalize_result; confirm which it is and
        # tighten this to a single expected value so the test cannot pass
        # on a column-scrambling regression.
        assert normalized == [("25", "bob"), ("30", "alice")] or normalized == [
            ("alice", "30"),
            ("bob", "25"),
        ]
|
||||
|
||||
|
||||
class TestIntegration:
    """Integration tests for the full scoring pipeline."""

    @staticmethod
    def _verdict(gold_sql, generated_sql):
        """Run gold then generated SQL on a fresh two-row table; return match."""
        conn = create_table_from_wikisql(
            ["name", "age"], [["Alice", "30"], ["Bob", "25"]]
        )
        try:
            gold_result = execute_sql(conn, gold_sql)
            gen_result = execute_sql(conn, generated_sql)
        finally:
            conn.close()
        return results_match(gen_result, gold_result)

    def test_correct_sql_scores_1(self):
        """An exact copy of the gold query matches."""
        gold = "SELECT name FROM data WHERE age = '30'"
        assert self._verdict(gold, "SELECT name FROM data WHERE age = '30'") is True

    def test_semantically_equivalent_sql(self):
        """Different casing/whitespace with the same result still matches."""
        gold = "SELECT name FROM data WHERE age = '30'"
        assert self._verdict(gold, "select NAME from data where AGE='30'") is True

    def test_wrong_sql_fails(self):
        """A query selecting a different row does not match."""
        gold = "SELECT name FROM data WHERE age = '30'"
        assert self._verdict(gold, "SELECT name FROM data WHERE age = '25'") is False
|
||||
227
environments/community/sql_query_env/wikisql_loader.py
Normal file
227
environments/community/sql_query_env/wikisql_loader.py
Normal file
|
|
@ -0,0 +1,227 @@
|
|||
"""
|
||||
WikiSQL Data Loader
|
||||
|
||||
Downloads and loads the WikiSQL dataset directly from the Salesforce GitHub repository.
|
||||
This module provides functionality to fetch the dataset, extract it, and load it
|
||||
in a format compatible with the SQL Query Environment.
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
import tarfile
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
from urllib.request import urlretrieve
|
||||
|
||||
# WikiSQL data source: the official Salesforce release tarball on GitHub.
WIKISQL_DATA_URL = "https://github.com/salesforce/WikiSQL/raw/master/data.tar.bz2"

# SQL operators from WikiSQL lib/query.py.
# AGG_OPS is indexed by the dataset's "agg" field (0 = no aggregation);
# COND_OPS is indexed by each condition's operator field ("OP" is the
# dataset's placeholder operator).
AGG_OPS = ["", "MAX", "MIN", "COUNT", "SUM", "AVG"]
COND_OPS = ["=", ">", "<", "OP"]
|
||||
|
||||
|
||||
def get_cache_dir() -> Path:
    """Return (creating it if needed) the local cache directory for WikiSQL data.

    Honors the WIKISQL_CACHE_DIR environment variable; defaults to
    ~/.cache/wikisql otherwise.
    """
    default_location = Path.home() / ".cache" / "wikisql"
    target = Path(os.environ.get("WIKISQL_CACHE_DIR", default_location))
    target.mkdir(parents=True, exist_ok=True)
    return target
|
||||
|
||||
|
||||
def download_wikisql(cache_dir: Optional[Path] = None, force: bool = False) -> Path:
    """
    Download the WikiSQL dataset from GitHub.

    Args:
        cache_dir: Directory to cache the downloaded data (defaults to the
            directory returned by get_cache_dir())
        force: Force re-download even if already cached

    Returns:
        Path to the extracted data directory (``<cache_dir>/data``)
    """
    if cache_dir is None:
        cache_dir = get_cache_dir()

    data_dir = cache_dir / "data"
    tar_file = cache_dir / "data.tar.bz2"

    # Check if already extracted. Only trust the cache when every split's
    # JSONL and DB file is present, so a partial extraction is redone.
    if data_dir.exists() and not force:
        expected_files = [
            "train.jsonl",
            "dev.jsonl",
            "test.jsonl",
            "train.db",
            "dev.db",
            "test.db",
        ]
        if all((data_dir / f).exists() for f in expected_files):
            return data_dir

    # Download if needed; a previously downloaded tarball is reused
    # unless force=True.
    if not tar_file.exists() or force:
        print(f"Downloading WikiSQL dataset from {WIKISQL_DATA_URL}...")
        urlretrieve(WIKISQL_DATA_URL, tar_file)
        print(f"Downloaded to {tar_file}")

    # Extract the bz2 tarball into the cache directory.
    # NOTE(review): extractall() without a member filter trusts the archive's
    # paths (path-traversal risk for hostile tarballs); acceptable for this
    # pinned Salesforce URL, but consider filter="data" once the project's
    # minimum Python guarantees it (3.12+) — confirm before changing.
    print(f"Extracting to {cache_dir}...")
    with tarfile.open(tar_file, "r:bz2") as tar:
        tar.extractall(cache_dir)
    print("Extraction complete.")

    return data_dir
|
||||
|
||||
|
||||
def build_human_readable_sql(
    table_header: List[str],
    sel: int,
    agg: int,
    conds: List[Tuple[int, int, str]],
) -> str:
    """
    Build a human-readable SQL query from WikiSQL format.

    Args:
        table_header: List of column names
        sel: Index of selected column
        agg: Index of aggregation operator (index into AGG_OPS; 0 = none)
        conds: List of (column_index, operator_index, condition) tuples

    Returns:
        Human-readable SQL query string over a table named ``data``
    """
    # Build SELECT clause; fall back to a positional name if the index
    # is out of range for this table's header.
    sel_col = table_header[sel] if sel < len(table_header) else f"col{sel}"
    if agg > 0:
        select_clause = f"{AGG_OPS[agg]}({sel_col})"
    else:
        select_clause = sel_col

    sql = f"SELECT {select_clause} FROM data"

    # Build WHERE clause
    if conds:
        where_parts = []
        for col_idx, op_idx, condition in conds:
            col_name = (
                table_header[col_idx]
                if col_idx < len(table_header)
                else f"col{col_idx}"
            )
            op = COND_OPS[op_idx] if op_idx < len(COND_OPS) else "="
            # Quote string conditions. Double any embedded single quotes so
            # WikiSQL values like "O'Brien" produce valid SQL instead of a
            # syntax error (standard SQLite string-literal escaping).
            if isinstance(condition, str):
                escaped = condition.replace("'", "''")
                where_parts.append(f"{col_name} {op} '{escaped}'")
            else:
                where_parts.append(f"{col_name} {op} {condition}")
        sql += " WHERE " + " AND ".join(where_parts)

    return sql
|
||||
|
||||
|
||||
def load_wikisql_split(
    split: str = "train",
    cache_dir: Optional[Path] = None,
) -> List[Dict[str, Any]]:
    """
    Load a WikiSQL dataset split.

    Args:
        split: One of 'train', 'dev', or 'test'
        cache_dir: Directory containing the WikiSQL data; when None the
            dataset is downloaded/cached via download_wikisql()

    Returns:
        List of dictionaries with question, table info, and gold SQL

    Raises:
        FileNotFoundError: If the split's JSONL or tables file is missing
            from ``cache_dir``.
    """
    if cache_dir is None:
        data_dir = download_wikisql()
    else:
        data_dir = cache_dir

    jsonl_file = data_dir / f"{split}.jsonl"
    tables_file = data_dir / f"{split}.tables.jsonl"

    if not jsonl_file.exists():
        raise FileNotFoundError(f"WikiSQL {split} JSONL file not found: {jsonl_file}")
    if not tables_file.exists():
        raise FileNotFoundError(f"WikiSQL {split} tables file not found: {tables_file}")

    # Load all tables from tables.jsonl (has real column headers), keyed by
    # table id so questions can be joined to their schema below.
    tables_cache: Dict[str, Dict[str, Any]] = {}
    with open(tables_file, "r", encoding="utf-8") as f:
        for line in f:
            table = json.loads(line.strip())
            tables_cache[table["id"]] = {
                "header": table["header"],
                "rows": table["rows"],
                # Some table records omit "types"; default every column to text.
                "types": table.get("types", ["text"] * len(table["header"])),
            }

    print(f"Loaded {len(tables_cache)} tables for {split} split")

    # Load questions and build dataset
    dataset: List[Dict[str, Any]] = []
    skipped = 0
    with open(jsonl_file, "r", encoding="utf-8") as f:
        for line in f:
            item = json.loads(line.strip())

            # Drop questions whose table is missing or has an empty header —
            # no schema means neither prompt nor gold SQL can be built.
            table_info = tables_cache.get(item["table_id"])
            if not table_info or not table_info["header"]:
                skipped += 1
                continue

            # Build human-readable SQL from the structured (sel, agg, conds)
            # annotation shipped with the dataset.
            conds = [(c[0], c[1], c[2]) for c in item["sql"]["conds"]]
            gold_sql = build_human_readable_sql(
                table_info["header"],
                item["sql"]["sel"],
                item["sql"]["agg"],
                conds,
            )

            dataset.append(
                {
                    "question": item["question"],
                    "header": table_info["header"],
                    "rows": table_info["rows"],
                    "types": table_info["types"],
                    "gold_sql": gold_sql,
                    "table_id": item["table_id"],
                }
            )

    if skipped > 0:
        print(f"Skipped {skipped} examples with missing tables")

    return dataset
|
||||
|
||||
|
||||
def load_wikisql(cache_dir: Optional[Path] = None) -> Dict[str, List[Dict[str, Any]]]:
    """
    Load the full WikiSQL dataset.

    Returns:
        Dictionary with 'train', 'dev', and 'test' splits
    """
    return {
        split: load_wikisql_split(split, cache_dir)
        for split in ("train", "dev", "test")
    }
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Smoke-test the loader: fetch/cache the data and show one example.
    print("Testing WikiSQL loader...")
    examples = load_wikisql_split("train")
    print(f"Loaded {len(examples)} training examples")
    if examples:
        sample = examples[0]
        print("\nFirst example:")
        print(f"  Question: {sample['question']}")
        print(f"  Header: {sample['header']}")
        print(f"  Gold SQL: {sample['gold_sql']}")
        print(f"  Rows (first 2): {sample['rows'][:2]}")
|
||||
Loading…
Add table
Add a link
Reference in a new issue