mirror of https://github.com/NousResearch/atropos.git, synced 2026-04-19 12:57:58 +00:00
Add SQL Query Generation Environment
parent 6d79d4c7ad
commit c927794248
6 changed files with 1427 additions and 0 deletions
227 environments/community/sql_query_env/wikisql_loader.py Normal file
@@ -0,0 +1,227 @@
"""
WikiSQL Data Loader

Downloads and loads the WikiSQL dataset directly from the Salesforce GitHub repository.
This module provides functionality to fetch the dataset, extract it, and load it
in a format compatible with the SQL Query Environment.
"""

import json
import os
import tarfile
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple
from urllib.request import urlretrieve

# WikiSQL data source
WIKISQL_DATA_URL = "https://github.com/salesforce/WikiSQL/raw/master/data.tar.bz2"

# SQL operators from WikiSQL lib/query.py
AGG_OPS = ["", "MAX", "MIN", "COUNT", "SUM", "AVG"]
COND_OPS = ["=", ">", "<", "OP"]

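# For reference: each raw WikiSQL example encodes its query as indices into
# the lists above, e.g. {"sel": 2, "agg": 4, "conds": [[1, 0, "Guard"]]}
# means SELECT AGG_OPS[4](header[2]) WHERE header[1] COND_OPS[0] 'Guard'.
# (Values here are illustrative, not taken from the dataset.)
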
def get_cache_dir() -> Path:
    """Get the cache directory for WikiSQL data."""
    cache_dir = Path(
        os.environ.get("WIKISQL_CACHE_DIR", Path.home() / ".cache" / "wikisql")
    )
    cache_dir.mkdir(parents=True, exist_ok=True)
    return cache_dir

def download_wikisql(cache_dir: Optional[Path] = None, force: bool = False) -> Path:
    """
    Download the WikiSQL dataset from GitHub.

    Args:
        cache_dir: Directory to cache the downloaded data
        force: Force re-download even if already cached

    Returns:
        Path to the extracted data directory
    """
    if cache_dir is None:
        cache_dir = get_cache_dir()

    data_dir = cache_dir / "data"
    tar_file = cache_dir / "data.tar.bz2"

    # Check if already extracted
    if data_dir.exists() and not force:
        expected_files = [
            "train.jsonl",
            "dev.jsonl",
            "test.jsonl",
            "train.db",
            "dev.db",
            "test.db",
        ]
        if all((data_dir / f).exists() for f in expected_files):
            return data_dir

    # Download if needed
    if not tar_file.exists() or force:
        print(f"Downloading WikiSQL dataset from {WIKISQL_DATA_URL}...")
        urlretrieve(WIKISQL_DATA_URL, tar_file)
        print(f"Downloaded to {tar_file}")

    # Extract
    print(f"Extracting to {cache_dir}...")
    with tarfile.open(tar_file, "r:bz2") as tar:
        tar.extractall(cache_dir)
    print("Extraction complete.")

    return data_dir

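# After download and extraction, the cache is expected to contain (per the
# checks above and the tables.jsonl lookups in load_wikisql_split below):
#   <cache_dir>/data/{train,dev,test}.jsonl
#   <cache_dir>/data/{train,dev,test}.tables.jsonl
#   <cache_dir>/data/{train,dev,test}.db
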
def build_human_readable_sql(
    table_header: List[str],
    sel: int,
    agg: int,
    conds: List[Tuple[int, int, Any]],
) -> str:
    """
    Build a human-readable SQL query from WikiSQL format.

    Args:
        table_header: List of column names
        sel: Index of selected column
        agg: Index of aggregation operator
        conds: List of (column_index, operator_index, condition) tuples;
            conditions may be strings or numbers

    Returns:
        Human-readable SQL query string
    """
    # Build SELECT clause
    sel_col = table_header[sel] if sel < len(table_header) else f"col{sel}"
    if agg > 0:
        select_clause = f"{AGG_OPS[agg]}({sel_col})"
    else:
        select_clause = sel_col

    sql = f"SELECT {select_clause} FROM data"

    # Build WHERE clause
    if conds:
        where_parts = []
        for col_idx, op_idx, condition in conds:
            col_name = (
                table_header[col_idx]
                if col_idx < len(table_header)
                else f"col{col_idx}"
            )
            op = COND_OPS[op_idx] if op_idx < len(COND_OPS) else "="
            # Quote string conditions
            if isinstance(condition, str):
                where_parts.append(f"{col_name} {op} '{condition}'")
            else:
                where_parts.append(f"{col_name} {op} {condition}")
        sql += " WHERE " + " AND ".join(where_parts)

    return sql

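# Worked example (illustrative values):
#   build_human_readable_sql(["Player", "Position", "Points"], 2, 4, [(1, 0, "Guard")])
#   -> "SELECT SUM(Points) FROM data WHERE Position = 'Guard'"
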
def load_wikisql_split(
    split: str = "train",
    cache_dir: Optional[Path] = None,
) -> List[Dict[str, Any]]:
    """
    Load a WikiSQL dataset split.

    Args:
        split: One of 'train', 'dev', or 'test'
        cache_dir: Directory containing the WikiSQL data

    Returns:
        List of dictionaries with question, table info, and gold SQL
    """
    if cache_dir is None:
        data_dir = download_wikisql()
    else:
        data_dir = cache_dir

    jsonl_file = data_dir / f"{split}.jsonl"
    tables_file = data_dir / f"{split}.tables.jsonl"

    if not jsonl_file.exists():
        raise FileNotFoundError(f"WikiSQL {split} JSONL file not found: {jsonl_file}")
    if not tables_file.exists():
        raise FileNotFoundError(f"WikiSQL {split} tables file not found: {tables_file}")

    # Load all tables from tables.jsonl (has real column headers)
    tables_cache = {}
    with open(tables_file, "r", encoding="utf-8") as f:
        for line in f:
            table = json.loads(line.strip())
            tables_cache[table["id"]] = {
                "header": table["header"],
                "rows": table["rows"],
                "types": table.get("types", ["text"] * len(table["header"])),
            }

    print(f"Loaded {len(tables_cache)} tables for {split} split")

    # Load questions and build dataset
    dataset = []
    skipped = 0
    with open(jsonl_file, "r", encoding="utf-8") as f:
        for line in f:
            item = json.loads(line.strip())

            table_info = tables_cache.get(item["table_id"])
            if not table_info or not table_info["header"]:
                skipped += 1
                continue

            # Build human-readable SQL
            conds = [(c[0], c[1], c[2]) for c in item["sql"]["conds"]]
            gold_sql = build_human_readable_sql(
                table_info["header"],
                item["sql"]["sel"],
                item["sql"]["agg"],
                conds,
            )

            dataset.append(
                {
                    "question": item["question"],
                    "header": table_info["header"],
                    "rows": table_info["rows"],
                    "types": table_info["types"],
                    "gold_sql": gold_sql,
                    "table_id": item["table_id"],
                }
            )

    if skipped > 0:
        print(f"Skipped {skipped} examples with missing tables")

    return dataset

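# Each record returned by load_wikisql_split has this shape (example values
# are illustrative):
#   {"question": "...", "header": ["Player", ...], "rows": [[...], ...],
#    "types": ["text", ...], "gold_sql": "SELECT ...", "table_id": "..."}
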
def load_wikisql(cache_dir: Optional[Path] = None) -> Dict[str, List[Dict[str, Any]]]:
    """
    Load the full WikiSQL dataset.

    Returns:
        Dictionary with 'train', 'dev', and 'test' splits
    """
    return {
        "train": load_wikisql_split("train", cache_dir),
        "dev": load_wikisql_split("dev", cache_dir),
        "test": load_wikisql_split("test", cache_dir),
    }

if __name__ == "__main__":
    # Test the loader
    print("Testing WikiSQL loader...")
    train = load_wikisql_split("train")
    print(f"Loaded {len(train)} training examples")
    if train:
        print("\nFirst example:")
        print(f"  Question: {train[0]['question']}")
        print(f"  Header: {train[0]['header']}")
        print(f"  Gold SQL: {train[0]['gold_sql']}")
        print(f"  Rows (first 2): {train[0]['rows'][:2]}")