Add SQL Query Generation Environment

PLippmann 2026-01-05 16:43:34 +01:00
parent 6d79d4c7ad
commit c927794248
6 changed files with 1427 additions and 0 deletions


@@ -0,0 +1,227 @@
"""
WikiSQL Data Loader

Downloads the WikiSQL dataset directly from the Salesforce GitHub repository,
extracts it, and loads it in a format compatible with the SQL Query
Environment.
"""
import json
import os
import tarfile
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple
from urllib.request import urlretrieve

# WikiSQL data source
WIKISQL_DATA_URL = "https://github.com/salesforce/WikiSQL/raw/master/data.tar.bz2"

# SQL operators from WikiSQL lib/query.py
AGG_OPS = ["", "MAX", "MIN", "COUNT", "SUM", "AVG"]
COND_OPS = ["=", ">", "<", "OP"]
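
# WikiSQL stores each query as a compact logical form rather than raw SQL:
# an index into the table header for the selected column, an index into
# AGG_OPS for the aggregation, and a list of (column, operator, value)
# conditions indexing into COND_OPS. Illustrative example (not a real
# dataset entry):
#   {"sel": 0, "agg": 3, "conds": [[2, 0, "Guard"]]}
# decodes to: SELECT COUNT(<column 0>) FROM table WHERE <column 2> = 'Guard'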


def get_cache_dir() -> Path:
"""Get the cache directory for WikiSQL data."""
cache_dir = Path(
os.environ.get("WIKISQL_CACHE_DIR", Path.home() / ".cache" / "wikisql")
)
cache_dir.mkdir(parents=True, exist_ok=True)
return cache_dir


def download_wikisql(cache_dir: Optional[Path] = None, force: bool = False) -> Path:
    """
    Download the WikiSQL dataset from GitHub.

    Args:
        cache_dir: Directory to cache the downloaded data
        force: Force re-download even if already cached

    Returns:
        Path to the extracted data directory
    """
if cache_dir is None:
cache_dir = get_cache_dir()
data_dir = cache_dir / "data"
tar_file = cache_dir / "data.tar.bz2"
# Check if already extracted
if data_dir.exists() and not force:
        # The per-split .tables.jsonl files are required by load_wikisql_split,
        # so include them in the cache-completeness check
        expected_files = [
            "train.jsonl",
            "dev.jsonl",
            "test.jsonl",
            "train.tables.jsonl",
            "dev.tables.jsonl",
            "test.tables.jsonl",
            "train.db",
            "dev.db",
            "test.db",
        ]
if all((data_dir / f).exists() for f in expected_files):
return data_dir
# Download if needed
if not tar_file.exists() or force:
print(f"Downloading WikiSQL dataset from {WIKISQL_DATA_URL}...")
urlretrieve(WIKISQL_DATA_URL, tar_file)
print(f"Downloaded to {tar_file}")
# Extract
print(f"Extracting to {cache_dir}...")
with tarfile.open(tar_file, "r:bz2") as tar:
tar.extractall(cache_dir)
print("Extraction complete.")
return data_dir
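
# Usage sketch (assumes network access to GitHub on the first call):
#   data_dir = download_wikisql()             # cached under ~/.cache/wikisql
#   data_dir = download_wikisql(force=True)   # re-download and re-extract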


def build_human_readable_sql(
    table_header: List[str],
    sel: int,
    agg: int,
    conds: List[Tuple[int, int, Any]],
) -> str:
    """
    Build a human-readable SQL query from WikiSQL format.

    Args:
        table_header: List of column names
        sel: Index of selected column
        agg: Index of aggregation operator
        conds: List of (column_index, operator_index, value) tuples;
            values may be strings or numbers

    Returns:
        Human-readable SQL query string
    """
# Build SELECT clause
sel_col = table_header[sel] if sel < len(table_header) else f"col{sel}"
if agg > 0:
select_clause = f"{AGG_OPS[agg]}({sel_col})"
else:
select_clause = sel_col
sql = f"SELECT {select_clause} FROM data"
# Build WHERE clause
if conds:
where_parts = []
for col_idx, op_idx, condition in conds:
col_name = (
table_header[col_idx]
if col_idx < len(table_header)
else f"col{col_idx}"
)
op = COND_OPS[op_idx] if op_idx < len(COND_OPS) else "="
            # Quote string values, escaping embedded single quotes so the
            # generated SQL stays well-formed
            if isinstance(condition, str):
                escaped = condition.replace("'", "''")
                where_parts.append(f"{col_name} {op} '{escaped}'")
            else:
                where_parts.append(f"{col_name} {op} {condition}")
sql += " WHERE " + " AND ".join(where_parts)
return sql
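
# Worked example (illustrative values, not taken from the dataset):
#   build_human_readable_sql(
#       ["Player", "No.", "Position"],
#       sel=0,
#       agg=3,                    # AGG_OPS[3] == "COUNT"
#       conds=[(2, 0, "Guard")],  # COND_OPS[0] == "="
#   )
# returns: "SELECT COUNT(Player) FROM data WHERE Position = 'Guard'"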


def load_wikisql_split(
    split: str = "train",
    cache_dir: Optional[Path] = None,
) -> List[Dict[str, Any]]:
    """
    Load a WikiSQL dataset split.

    Args:
        split: One of 'train', 'dev', or 'test'
        cache_dir: Directory containing the WikiSQL data

    Returns:
        List of dictionaries with question, table info, and gold SQL
    """
if cache_dir is None:
data_dir = download_wikisql()
else:
data_dir = cache_dir
jsonl_file = data_dir / f"{split}.jsonl"
tables_file = data_dir / f"{split}.tables.jsonl"
if not jsonl_file.exists():
raise FileNotFoundError(f"WikiSQL {split} JSONL file not found: {jsonl_file}")
if not tables_file.exists():
raise FileNotFoundError(f"WikiSQL {split} tables file not found: {tables_file}")
# Load all tables from tables.jsonl (has real column headers)
tables_cache = {}
with open(tables_file, "r", encoding="utf-8") as f:
for line in f:
table = json.loads(line.strip())
tables_cache[table["id"]] = {
"header": table["header"],
"rows": table["rows"],
"types": table.get("types", ["text"] * len(table["header"])),
}
print(f"Loaded {len(tables_cache)} tables for {split} split")
# Load questions and build dataset
dataset = []
skipped = 0
with open(jsonl_file, "r", encoding="utf-8") as f:
for line in f:
item = json.loads(line.strip())
table_info = tables_cache.get(item["table_id"])
if not table_info or not table_info["header"]:
skipped += 1
continue
# Build human-readable SQL
conds = [(c[0], c[1], c[2]) for c in item["sql"]["conds"]]
gold_sql = build_human_readable_sql(
table_info["header"],
item["sql"]["sel"],
item["sql"]["agg"],
conds,
)
dataset.append(
{
"question": item["question"],
"header": table_info["header"],
"rows": table_info["rows"],
"types": table_info["types"],
"gold_sql": gold_sql,
"table_id": item["table_id"],
}
)
if skipped > 0:
print(f"Skipped {skipped} examples with missing tables")
return dataset
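
# Shape of each returned record (all values illustrative):
#   {
#       "question": "How many players play guard?",
#       "header": ["Player", "No.", "Position"],
#       "rows": [["Mike Conley", "11", "Guard"], ...],
#       "types": ["text", "text", "text"],
#       "gold_sql": "SELECT COUNT(Player) FROM data WHERE Position = 'Guard'",
#       "table_id": "1-10015132-11",
#   }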


def load_wikisql(cache_dir: Optional[Path] = None) -> Dict[str, List[Dict[str, Any]]]:
    """
    Load the full WikiSQL dataset.

    Args:
        cache_dir: Directory containing the WikiSQL data

    Returns:
        Dictionary with 'train', 'dev', and 'test' splits
    """
    return {
        "train": load_wikisql_split("train", cache_dir),
        "dev": load_wikisql_split("dev", cache_dir),
        "test": load_wikisql_split("test", cache_dir),
    }


if __name__ == "__main__":
# Test the loader
print("Testing WikiSQL loader...")
train = load_wikisql_split("train")
print(f"Loaded {len(train)} training examples")
if train:
print("\nFirst example:")
print(f" Question: {train[0]['question']}")
print(f" Header: {train[0]['header']}")
print(f" Gold SQL: {train[0]['gold_sql']}")
print(f" Rows (first 2): {train[0]['rows'][:2]}")
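
        # Optional round-trip check, a sketch rather than part of the loader:
        # copy the first example's rows into an in-memory SQLite table named
        # "data" and run the gold SQL against it. CREATE TABLE quotes the
        # column names because WikiSQL headers may contain spaces, but the
        # gold SQL uses raw names, so execution only succeeds when the
        # header happens to be a valid identifier.
        import sqlite3

        ex = train[0]
        quoted = ", ".join('"{}"'.format(h.replace('"', '""')) for h in ex["header"])
        conn = sqlite3.connect(":memory:")
        conn.execute(f"CREATE TABLE data ({quoted})")
        placeholders = ", ".join("?" for _ in ex["header"])
        conn.executemany(f"INSERT INTO data VALUES ({placeholders})", ex["rows"])
        try:
            result = conn.execute(ex["gold_sql"]).fetchall()
            print(f"  Gold SQL executed, first rows: {result[:2]}")
        except sqlite3.OperationalError as exc:
            print(f"  Gold SQL not directly executable here: {exc}")
        finally:
            conn.close()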