# Mirror of https://github.com/NousResearch/atropos.git
# Synced 2026-04-19 12:57:58 +00:00 (424 lines, 14 KiB, Python)
"""
|
|
SQL Query Generation Environment for Atropos
|
|
|
|
Trains LLMs to generate correct SQL queries from natural language questions.
|
|
Uses the Salesforce/WikiSQL dataset with execution-based scoring.
|
|
"""
|
|
|
|
import random
import time
from typing import Any, Dict, List, Optional, Tuple, TypedDict, Union

from tqdm.asyncio import tqdm_asyncio

from atroposlib.envs.base import (
    APIServerConfig,
    BaseEnv,
    BaseEnvConfig,
    ScoredDataGroup,
)
from atroposlib.type_definitions import Item
from sql_executor import (
    create_table_from_wikisql,
    execute_sql,
    extract_boxed_sql,
    quote_identifiers_in_sql,
    results_match,
)
from wikisql_loader import load_wikisql_split
|
|
|
|
# System prompt following the established Atropos pattern:
# a generic "deep thinking" preamble instructing the model to reason
# inside <think> </think> tags, with the SQL task spec appended below.
system_prompt = (
    "You are a deep thinking AI, you may use extremely long chains of thought "
    "to deeply consider the problem and deliberate with yourself via systematic "
    "reasoning processes to help come to a correct solution prior to answering. "
    "You should enclose your thoughts and internal monologue inside <think> </think> "
    "tags, and then provide your solution or response to the problem.\n\n"
)

# Task-specific instructions: schema + question in, \boxed{SQL} out.
# The \boxed{...} wrapper is what extract_boxed_sql parses from responses,
# and the table name "data" matches the table built by create_table_from_wikisql.
system_prompt += """You are a SQL expert. Given a table schema and a natural language question,
generate a SQL query that answers the question.

You are allocated a maximum of 1024 tokens, please strive to use less.

Provide your SQL query inside \\boxed{} like this: \\boxed{SELECT column FROM table WHERE condition}

Important:
- Use only the columns provided in the table schema
- The table is always named "data"
- Ensure your SQL syntax is valid SQLite

So please end your answer with \\boxed{your SQL query here}"""
|
|
|
|
|
|
class WikiSQLRow(TypedDict):
    """Type definition for a WikiSQL dataset row."""

    question: str  # natural-language question about the table
    header: List[str]  # column names, used to build the SQLite table
    rows: List[List[Any]]  # table contents, one inner list per row
    types: List[str]  # per-column type strings (unused in this file)
    gold_sql: str  # reference SQL query to compare execution results against
|
|
|
|
|
|
def format_table_schema(
    header: List[str], rows: List[List[Any]], max_rows: int = 3
) -> str:
    """Render a prompt-friendly description of a WikiSQL table.

    Emits the fixed table name ("data"), the comma-separated column list,
    and — when the table is non-empty — up to ``max_rows`` sample rows with
    cells joined by " | ".
    """
    pieces = [f"Table: data\nColumns: {', '.join(header)}\n"]
    if rows:
        pieces.append("Sample data:\n")
        pieces.extend(
            f" {' | '.join(str(cell) for cell in sample_row)}\n"
            for sample_row in rows[:max_rows]
        )
    return "".join(pieces)
|
|
|
|
|
|
class SQLQueryEnv(BaseEnv):
    """
    Environment for training LLMs to generate SQL queries.

    Uses the Salesforce/WikiSQL dataset and verifies correctness by
    executing SQL against in-memory SQLite databases.  A rollout earns
    +1.0 only when its query executes and returns the same result set
    as the gold query, and -1.0 otherwise (including extraction or
    execution failures).
    """

    name = "sql_query"

    def __init__(
        self,
        config: BaseEnvConfig,
        server_configs: List[APIServerConfig],
        slurm=True,
        testing=False,
    ):
        super().__init__(config, server_configs, slurm, testing)
        # 0/1 per scored rollout: was the reward positive?  Averaged in wandb_log.
        self.percent_correct_buffer = list()
        # (metric_name, value) tuples queued by evaluate() for the next wandb_log.
        self.eval_metrics = list()
        # 0/1 per scored rollout: did the generated SQL execute at all?
        self.execution_success_buffer = list()

    @classmethod
    def config_init(cls) -> Tuple[BaseEnvConfig, List[APIServerConfig]]:
        """Initialize default configuration for the environment."""
        env_config = BaseEnvConfig(
            tokenizer_name="NousResearch/DeepHermes-3-Llama-3-3B-Preview",
            group_size=8,
            use_wandb=True,
            rollout_server_url="http://localhost:8000",
            total_steps=1000,
            batch_size=12,
            steps_per_eval=100,
            max_token_length=1024,
            wandb_name="sql_query",
        )
        server_configs = [
            APIServerConfig(
                model_name="NousResearch/DeepHermes-3-Llama-3-3B-Preview",
                base_url="http://localhost:9001/v1",
                api_key="x",
                num_requests_for_eval=256,
            ),
        ]
        return env_config, server_configs

    async def wandb_log(self, wandb_metrics: Optional[Dict] = None):
        """Log custom metrics to WandB, then flush the rolling buffers."""
        if wandb_metrics is None:
            wandb_metrics = {}

        # Share of rollouts with a positive reward since the last log.
        # (Truthiness check replaces the original try/except ZeroDivisionError.)
        if self.percent_correct_buffer:
            wandb_metrics["train/percent_correct"] = sum(
                self.percent_correct_buffer
            ) / len(self.percent_correct_buffer)

        # Share of rollouts whose SQL at least executed (even if wrong).
        if self.execution_success_buffer:
            wandb_metrics["train/execution_success"] = sum(
                self.execution_success_buffer
            ) / len(self.execution_success_buffer)

        self.percent_correct_buffer = list()
        self.execution_success_buffer = list()

        # Drain metrics queued by evaluate().
        for item in self.eval_metrics:
            wandb_metrics[item[0]] = item[1]
        self.eval_metrics = list()

        await super().wandb_log(wandb_metrics)

    async def setup(self):
        """Load the WikiSQL dataset and prepare train/test splits."""
        # Load WikiSQL dataset directly from GitHub source
        print("Loading WikiSQL training data...")
        self.train = load_wikisql_split("train")
        print(f"Loaded {len(self.train)} training examples")

        print("Loading WikiSQL test data...")
        self.test = load_wikisql_split("test")
        print(f"Loaded {len(self.test)} test examples")

        random.shuffle(self.train)
        self.iter = 0  # advanced by get_next_item; persisted by save_checkpoint

    def save_checkpoint(self, step, data=None):
        """Save checkpoint with iteration state."""
        if data is None:
            data = {}
        data["iter"] = self.iter
        super().save_checkpoint(step, data)

    def _score_sql(
        self,
        generated_sql: str,
        gold_sql: str,
        header: List[str],
        rows: List[List[Any]],
    ) -> Tuple[float, bool]:
        """
        Score a generated query by executing it and comparing result sets.

        Args:
            generated_sql: SQL extracted from the model response (falsy if none).
            gold_sql: Reference query from the dataset.
            header: Column names used to build the in-memory table.
            rows: Table contents.

        Returns:
            Tuple of (score, execution_success): score is +1.0 only when both
            queries run and their results match, -1.0 otherwise;
            execution_success reports whether the generated SQL ran at all.
        """
        if not generated_sql:
            return -1.0, False

        # Build the in-memory SQLite table both queries run against.
        try:
            conn = create_table_from_wikisql(header, rows)
        except Exception:
            return -1.0, False

        # Quote identifiers in gold SQL that need quoting (e.g., "State/territory")
        quoted_gold_sql = quote_identifiers_in_sql(gold_sql, header)

        # If even the gold query fails, the example is unusable.
        gold_result = execute_sql(conn, quoted_gold_sql)
        if gold_result is None:
            conn.close()
            return -1.0, False

        gen_result = execute_sql(conn, generated_sql)
        conn.close()

        if gen_result is None:
            return -1.0, False

        # Executed fine; the reward depends on whether the results match.
        if results_match(gen_result, gold_result):
            return 1.0, True
        else:
            return -1.0, True

    async def rollout_and_score_eval(
        self, question: str, gold_sql: str, header: List[str], rows: List[List[Any]]
    ) -> dict:
        """Rollout and score a single evaluation item.

        Returns:
            {"score": 1 or 0, "sample": full transcript plus scoring details}.
        """
        table_schema = format_table_schema(header, rows)
        user_content = f"{table_schema}\nQuestion: {question}"

        async with self.server.managed_server(tokenizer=self.tokenizer) as managed:
            completion = await managed.chat_completion(
                messages=[
                    {"role": "system", "content": system_prompt},
                    {"role": "user", "content": user_content},
                ],
                n=1,
                max_tokens=self.config.max_token_length,
                temperature=0.6,
            )
            response_content = completion.choices[0].message.content

            # Extract and score generated SQL
            generated_sql = extract_boxed_sql(response_content)
            score, exec_success = self._score_sql(
                generated_sql, gold_sql, header, rows
            )
            correct = score == 1.0

            sample = {
                "messages": [
                    {"role": "system", "content": system_prompt},
                    {"role": "user", "content": user_content},
                    {"role": "assistant", "content": response_content},
                ],
                "question": question,
                "gold_sql": gold_sql,
                "generated_sql": generated_sql,
                "score": 1 if correct else 0,
                "correct": correct,
                "execution_success": exec_success,
                "finish_reason": completion.choices[0].finish_reason,
            }

            return {"score": 1 if correct else 0, "sample": sample}

    async def evaluate(self, *args, **kwargs):
        """Run evaluation on (a slice of) the test set and queue metrics."""
        start_time = time.time()

        eval_tasks = []
        # Limit eval to first 500 items for efficiency
        for item in self.test[:500]:
            eval_tasks.append(
                self.rollout_and_score_eval(
                    item["question"], item["gold_sql"], item["header"], item["rows"]
                )
            )
        results = await tqdm_asyncio.gather(*eval_tasks)

        scores = [result["score"] for result in results]
        samples = [result["sample"] for result in results]

        # Guard against an empty test slice.
        percent_correct = sum(scores) / len(scores) if scores else 0

        end_time = time.time()

        # Queued here; emitted on the next wandb_log() call.
        self.eval_metrics.append(("eval/percent_correct", percent_correct))

        eval_metrics = {
            "eval/percent_correct": percent_correct,
        }

        await self.evaluate_log(
            metrics=eval_metrics,
            samples=samples,
            start_time=start_time,
            end_time=end_time,
            generation_parameters={
                "temperature": 0.6,
                "max_tokens": self.config.max_token_length,
            },
        )

    async def collect_trajectories(
        self, item: WikiSQLRow
    ) -> Tuple[ScoredDataGroup, list[Item]]:
        """Sample group_size SQL generations for one question and score them.

        Returns:
            (scored data group or None, backlog items — always empty here).
        """
        table_schema = format_table_schema(item["header"], item["rows"])
        user_content = f"{table_schema}\nQuestion: {item['question']}"
        user_message = {"role": "user", "content": user_content}

        async with self.server.managed_server(tokenizer=self.tokenizer) as managed:
            chat_completions = await managed.chat_completion(
                messages=[{"role": "system", "content": system_prompt}, user_message],
                n=self.config.group_size,
                # Higher temperature than eval (0.6) to diversify the group.
                temperature=1.0,
                max_tokens=self.config.max_token_length,
            )

            # Per-choice token/mask/logprob records from the managed server.
            state = managed.get_state()
            nodes = state["nodes"]

            to_score = list()
            to_backlog = list()

            for i, chat_completion in enumerate(chat_completions.choices):
                messages = (
                    {"role": "system", "content": system_prompt},
                    user_message,
                    {"role": "assistant", "content": chat_completion.message.content},
                )
                to_score.append(
                    {
                        "messages": messages,
                        "gold_sql": item["gold_sql"],
                        "header": item["header"],
                        "rows": item["rows"],
                        "finish_reason": chat_completion.finish_reason,
                        "tokens": nodes[i].tokens,
                        "masks": nodes[i].masked_tokens,
                        "logprobs": nodes[i].logprobs,
                    }
                )

            to_postprocess = await self.score(to_score)
            return to_postprocess, to_backlog

    async def score(
        self, rollout_group_data
    ) -> Union[Optional[ScoredDataGroup], List[Optional[ScoredDataGroup]]]:
        """Score generated SQL queries by execution comparison.

        Returns None when the group carries no training signal: every
        rollout was filtered out, or all rewards ended up identical.
        """
        scores = ScoredDataGroup()
        scores["tokens"] = list()
        scores["masks"] = list()
        scores["scores"] = list()
        scores["inference_logprobs"] = list()

        # The whole group shares one question/table; read it from any item.
        gold_sql = rollout_group_data[0]["gold_sql"]
        header = rollout_group_data[0]["header"]
        rows = rollout_group_data[0]["rows"]

        random.shuffle(rollout_group_data)

        for item in rollout_group_data:
            response_content = item["messages"][-1]["content"]

            # Extract SQL from response
            generated_sql = extract_boxed_sql(response_content)

            # Score by execution
            reward, exec_success = self._score_sql(
                generated_sql, gold_sql, header, rows
            )
            self.execution_success_buffer.append(1 if exec_success else 0)

            tokens = item["tokens"]
            masks = item["masks"]
            logprobs = item["logprobs"]

            # Remove obviously bad examples (almost nothing unmasked).
            if len([1 for i in masks if i != -100]) < 10:
                continue

            scores["tokens"].append(tokens)
            scores["masks"].append(masks)
            scores["inference_logprobs"].append(logprobs)
            scores["scores"].append(reward)

            if len(scores["tokens"]) >= self.config.group_size:
                break

        # Bug fix: the original fell through with an empty group, where
        # all([]) is True so max([]) raised ValueError (or, otherwise,
        # scores["scores"][0] raised IndexError below).
        if not scores["scores"]:
            return None

        for score in scores["scores"]:
            self.percent_correct_buffer.append(max(score, 0))

        # When every rollout is correct, reward shorter answers instead.
        if all([score == 1 for score in scores["scores"]]):
            # Apply length penalty when all are correct
            token_lengths = [len(token) for token in scores["tokens"]]
            if max(token_lengths) == 0:
                return None

            max_allowed_length = self.config.max_token_length
            # Penalty ramps linearly from 0 at half the budget to 1 at full.
            length_threshold = max_allowed_length * 0.5

            scores["scores"] = []
            for length in token_lengths:
                if length <= length_threshold:
                    scores["scores"].append(1.0)
                else:
                    percentage_of_range = (length - length_threshold) / (
                        max_allowed_length - length_threshold
                    )
                    percentage_of_range = min(percentage_of_range, 1.0)
                    scores["scores"].append(1.0 - percentage_of_range)

        # A group with identical rewards provides zero advantage signal.
        if all([scores["scores"][0] == score for score in scores["scores"]]):
            return None  # If all the same, return None

        return scores

    async def get_next_item(self) -> WikiSQLRow:
        """Get the next training item, cycling through the shuffled split."""
        next_item = self.train[self.iter % len(self.train)]
        self.iter += 1
        return next_item
|
|
|
|
|
|
# Script entry point: delegate argument parsing and startup to the
# Atropos BaseEnv CLI.
if __name__ == "__main__":
    SQLQueryEnv.cli()
|