mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-19 12:57:58 +00:00
168 lines
5.1 KiB
Python
168 lines
5.1 KiB
Python
"""
|
|
T1 dataset loader — downloads and parses the capitalone/T1 HuggingFace dataset.
|
|
|
|
Conversations are returned in the format expected by t1_core:
|
|
[{"Role": "assistant"|"user", "Filled_Template": str, "Filled_Plan": str}, ...]
|
|
"""
|
|
|
|
import logging
|
|
import random
|
|
from typing import Dict, List, Optional, Tuple
|
|
|
|
import pandas as pd
|
|
from huggingface_hub import hf_hub_download, list_repo_tree
|
|
|
|
logger = logging.getLogger(__name__)

# HuggingFace dataset repository holding the T1 conversations.
REPO_ID = "capitalone/T1"

# Single-domain only for now (simpler tool defs, shorter conversations)
SINGLE_DOMAINS = ["hotel", "flight", "restaurant", "attraction"]
# Multi-domain conversation sets (concatenated domain names).
MULTI_DOMAINS = [
    "flighthotel",
    "hotelrestaurant",
    "hotelattraction",
    "flighthotelrestaurant",
    "flighthotelattraction",
]
ALL_DOMAINS = SINGLE_DOMAINS + MULTI_DOMAINS
|
|
|
|
|
|
def _parse_role(filled_template: str) -> Tuple[str, str]:
|
|
"""Extract role and content from 'role: content' format."""
|
|
if filled_template.startswith("assistant:"):
|
|
return "assistant", filled_template[len("assistant:") :].strip()
|
|
elif filled_template.startswith("user:"):
|
|
return "user", filled_template[len("user:") :].strip()
|
|
else:
|
|
# Fallback — try to guess
|
|
return "assistant", filled_template.strip()
|
|
|
|
|
|
def _csv_to_conversations(path: str) -> Dict[int, List[dict]]:
|
|
"""Parse a T1 CSV file into a dict of conversation_id → turns."""
|
|
df = pd.read_csv(path)
|
|
conversations = {}
|
|
|
|
for conv_id, group in df.groupby("ID"):
|
|
turns = []
|
|
for _, row in group.iterrows():
|
|
template = row["Filled_Template"]
|
|
plan = row["Filled_Plan"]
|
|
|
|
# Skip NaN/empty templates
|
|
if pd.isna(template) or str(template).strip() in ("", "nan"):
|
|
continue
|
|
|
|
template = str(template)
|
|
role, content = _parse_role(template)
|
|
|
|
# Skip empty content after role extraction
|
|
if not content.strip():
|
|
continue
|
|
|
|
# NaN plans become empty string
|
|
if pd.isna(plan) or str(plan).strip() == "nan":
|
|
plan = ""
|
|
else:
|
|
plan = str(plan)
|
|
|
|
turns.append(
|
|
{
|
|
"Role": role,
|
|
"Filled_Template": content,
|
|
"Filled_Plan": plan,
|
|
}
|
|
)
|
|
|
|
# Only keep conversations with at least one user turn
|
|
if turns and any(t["Role"] == "user" for t in turns):
|
|
conversations[int(conv_id)] = turns
|
|
|
|
return conversations
|
|
|
|
|
|
def load_t1_dataset(
    domains: Optional[List[str]] = None,
    split: str = "train",
    max_files_per_domain: Optional[int] = None,
    cache_dir: Optional[str] = None,
) -> List[List[dict]]:
    """Load T1 conversations from HuggingFace.

    Args:
        domains: List of domains to load (default: single-domain only)
        split: "train", "test", or "validation"
        max_files_per_domain: Limit files per domain (each has 25, ~15 convos each)
        cache_dir: HF cache directory

    Returns:
        List of conversations, each a list of turn dicts
    """
    if domains is None:
        domains = SINGLE_DOMAINS

    all_conversations: List[List[dict]] = []

    # Enumerate the repo tree once; filter CSVs per domain/split below.
    repo_files = list(list_repo_tree(REPO_ID, repo_type="dataset", recursive=True))
    # Entries without a `size` attribute are directories, not files.
    csv_files = [
        f.path for f in repo_files if hasattr(f, "size") and f.path.endswith(".csv")
    ]

    # hf_hub_download kwargs are loop-invariant — build them once, not per file.
    download_kwargs = {"cache_dir": cache_dir} if cache_dir else {}

    for domain in domains:
        prefix = f"{domain}/{split}/"
        domain_files = sorted(f for f in csv_files if f.startswith(prefix))

        # `is not None` so an explicit limit of 0 loads zero files; the old
        # truthiness check silently ignored a 0 limit and loaded everything.
        if max_files_per_domain is not None:
            domain_files = domain_files[:max_files_per_domain]

        logger.info(f"Loading {len(domain_files)} files from {domain}/{split}")

        for file_path in domain_files:
            local_path = hf_hub_download(
                REPO_ID, file_path, repo_type="dataset", **download_kwargs
            )
            convos = _csv_to_conversations(local_path)
            all_conversations.extend(convos.values())

    logger.info(f"Loaded {len(all_conversations)} conversations total")
    return all_conversations
|
|
|
|
|
|
def load_t1_split(
    domains: Optional[List[str]] = None,
    max_files_per_domain: Optional[int] = None,
    eval_ratio: float = 0.1,
    seed: int = 42,
) -> Tuple[List[List[dict]], List[List[dict]]]:
    """Load T1 train conversations and split them into train/eval subsets.

    Args:
        domains: Domains to load
        max_files_per_domain: Limit files per domain
        eval_ratio: Fraction of conversations for eval
        seed: Random seed for split

    Returns:
        (train_conversations, eval_conversations)
    """
    convos = load_t1_dataset(
        domains=domains,
        split="train",
        max_files_per_domain=max_files_per_domain,
    )

    # Seeded shuffle so the same seed always produces the same split.
    random.Random(seed).shuffle(convos)

    # Reserve at least one conversation for eval, even for tiny datasets.
    eval_count = max(1, int(len(convos) * eval_ratio))
    eval_convos, train_convos = convos[:eval_count], convos[eval_count:]

    logger.info(f"Split: {len(train_convos)} train, {len(eval_convos)} eval")
    return train_convos, eval_convos
|