atropos/environments/t1_tool_planning/t1_data.py
2026-03-03 19:51:30 -06:00

168 lines
5.1 KiB
Python

"""
T1 dataset loader — downloads and parses the capitalone/T1 HuggingFace dataset.
Conversations are returned in the format expected by t1_core:
[{"Role": "assistant"|"user", "Filled_Template": str, "Filled_Plan": str}, ...]
"""
import logging
import random
from typing import Dict, List, Optional, Tuple
import pandas as pd
from huggingface_hub import hf_hub_download, list_repo_tree
logger = logging.getLogger(__name__)
REPO_ID = "capitalone/T1"
# Single-domain only for now (simpler tool defs, shorter conversations)
SINGLE_DOMAINS = ["hotel", "flight", "restaurant", "attraction"]
MULTI_DOMAINS = [
"flighthotel",
"hotelrestaurant",
"hotelattraction",
"flighthotelrestaurant",
"flighthotelattraction",
]
ALL_DOMAINS = SINGLE_DOMAINS + MULTI_DOMAINS
def _parse_role(filled_template: str) -> Tuple[str, str]:
"""Extract role and content from 'role: content' format."""
if filled_template.startswith("assistant:"):
return "assistant", filled_template[len("assistant:") :].strip()
elif filled_template.startswith("user:"):
return "user", filled_template[len("user:") :].strip()
else:
# Fallback — try to guess
return "assistant", filled_template.strip()
def _csv_to_conversations(path: str) -> Dict[int, List[dict]]:
"""Parse a T1 CSV file into a dict of conversation_id → turns."""
df = pd.read_csv(path)
conversations = {}
for conv_id, group in df.groupby("ID"):
turns = []
for _, row in group.iterrows():
template = row["Filled_Template"]
plan = row["Filled_Plan"]
# Skip NaN/empty templates
if pd.isna(template) or str(template).strip() in ("", "nan"):
continue
template = str(template)
role, content = _parse_role(template)
# Skip empty content after role extraction
if not content.strip():
continue
# NaN plans become empty string
if pd.isna(plan) or str(plan).strip() == "nan":
plan = ""
else:
plan = str(plan)
turns.append(
{
"Role": role,
"Filled_Template": content,
"Filled_Plan": plan,
}
)
# Only keep conversations with at least one user turn
if turns and any(t["Role"] == "user" for t in turns):
conversations[int(conv_id)] = turns
return conversations
def load_t1_dataset(
    domains: Optional[List[str]] = None,
    split: str = "train",
    max_files_per_domain: Optional[int] = None,
    cache_dir: Optional[str] = None,
) -> List[List[dict]]:
    """Download and parse T1 conversations from HuggingFace.

    Args:
        domains: Domains to load (default: single-domain only)
        split: "train", "test", or "validation"
        max_files_per_domain: Limit files per domain (each has 25, ~15 convos each)
        cache_dir: HF cache directory

    Returns:
        List of conversations, each a list of turn dicts
    """
    selected_domains = SINGLE_DOMAINS if domains is None else domains
    # Enumerate the repo once up front; keep only real CSV files
    # (entries without a `size` attribute are directories).
    tree_entries = list(list_repo_tree(REPO_ID, repo_type="dataset", recursive=True))
    csv_paths = [
        entry.path
        for entry in tree_entries
        if hasattr(entry, "size") and entry.path.endswith(".csv")
    ]
    # cache_dir is constant across downloads, so build kwargs once.
    download_kwargs = {"cache_dir": cache_dir} if cache_dir else {}
    loaded: List[List[dict]] = []
    for domain in selected_domains:
        wanted_prefix = f"{domain}/{split}/"
        matching = sorted(p for p in csv_paths if p.startswith(wanted_prefix))
        if max_files_per_domain:
            matching = matching[:max_files_per_domain]
        logger.info(f"Loading {len(matching)} files from {domain}/{split}")
        for repo_path in matching:
            local_csv = hf_hub_download(
                REPO_ID, repo_path, repo_type="dataset", **download_kwargs
            )
            loaded.extend(_csv_to_conversations(local_csv).values())
    logger.info(f"Loaded {len(loaded)} conversations total")
    return loaded
def load_t1_split(
    domains: Optional[List[str]] = None,
    max_files_per_domain: Optional[int] = None,
    eval_ratio: float = 0.1,
    seed: int = 42,
) -> Tuple[List[List[dict]], List[List[dict]]]:
    """Load T1 train conversations and split into train/eval.

    Args:
        domains: Domains to load
        max_files_per_domain: Limit files per domain
        eval_ratio: Fraction of conversations for eval
        seed: Random seed for split

    Returns:
        (train_conversations, eval_conversations)
    """
    all_convos = load_t1_dataset(
        domains=domains,
        split="train",
        max_files_per_domain=max_files_per_domain,
    )
    # Seeded shuffle so the train/eval partition is reproducible.
    shuffler = random.Random(seed)
    shuffler.shuffle(all_convos)
    # At least one conversation always lands in eval.
    eval_count = max(1, int(len(all_convos) * eval_ratio))
    held_out, remainder = all_convos[:eval_count], all_convos[eval_count:]
    logger.info(f"Split: {len(remainder)} train, {len(held_out)} eval")
    return remainder, held_out