mirror of
https://github.com/NousResearch/atropos.git
synced 2026-05-02 17:45:50 +00:00
add env using the tool api stuff
This commit is contained in:
parent
c8eb63f33d
commit
12d61d197f
15 changed files with 2632 additions and 21 deletions
168
environments/t1_tool_planning/t1_data.py
Normal file
168
environments/t1_tool_planning/t1_data.py
Normal file
|
|
@ -0,0 +1,168 @@
|
|||
"""
|
||||
T1 dataset loader — downloads and parses the capitalone/T1 HuggingFace dataset.
|
||||
|
||||
Conversations are returned in the format expected by t1_core:
|
||||
[{"Role": "assistant"|"user", "Filled_Template": str, "Filled_Plan": str}, ...]
|
||||
"""
|
||||
|
||||
import logging
|
||||
import random
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import pandas as pd
|
||||
from huggingface_hub import hf_hub_download, list_repo_tree
|
||||
|
||||
# Module-level logger, named after this module per stdlib convention.
logger = logging.getLogger(__name__)

# HuggingFace dataset repository the loaders below pull CSV files from.
REPO_ID = "capitalone/T1"

# Single-domain only for now (simpler tool defs, shorter conversations)
SINGLE_DOMAINS = ["hotel", "flight", "restaurant", "attraction"]
# Multi-domain folders combine two or three of the single domains above.
MULTI_DOMAINS = [
    "flighthotel",
    "hotelrestaurant",
    "hotelattraction",
    "flighthotelrestaurant",
    "flighthotelattraction",
]
# NOTE(review): presumably the complete set of domain folders in the repo —
# confirm against the actual HF repo tree.
ALL_DOMAINS = SINGLE_DOMAINS + MULTI_DOMAINS
|
||||
|
||||
|
||||
def _parse_role(filled_template: str) -> Tuple[str, str]:
|
||||
"""Extract role and content from 'role: content' format."""
|
||||
if filled_template.startswith("assistant:"):
|
||||
return "assistant", filled_template[len("assistant:") :].strip()
|
||||
elif filled_template.startswith("user:"):
|
||||
return "user", filled_template[len("user:") :].strip()
|
||||
else:
|
||||
# Fallback — try to guess
|
||||
return "assistant", filled_template.strip()
|
||||
|
||||
|
||||
def _csv_to_conversations(path: str) -> Dict[int, List[dict]]:
    """Read one T1 CSV and group its rows into conversations.

    Rows are grouped by the "ID" column; each surviving row becomes a turn
    dict with "Role", "Filled_Template" (role prefix stripped), and
    "Filled_Plan" keys. Conversations without any user turn are dropped.
    """
    frame = pd.read_csv(path)
    result: Dict[int, List[dict]] = {}

    for conv_id, group in frame.groupby("ID"):
        parsed_turns = []
        for _, record in group.iterrows():
            raw_template = record["Filled_Template"]
            raw_plan = record["Filled_Plan"]

            # Drop rows whose template is missing or effectively empty
            # (NaN, blank, or the literal string "nan").
            if pd.isna(raw_template) or str(raw_template).strip() in ("", "nan"):
                continue

            role, content = _parse_role(str(raw_template))
            if not content.strip():
                # Role prefix with nothing after it — skip.
                continue

            # Missing plans are normalized to an empty string.
            if pd.isna(raw_plan) or str(raw_plan).strip() == "nan":
                plan_text = ""
            else:
                plan_text = str(raw_plan)

            parsed_turns.append(
                {
                    "Role": role,
                    "Filled_Template": content,
                    "Filled_Plan": plan_text,
                }
            )

        # Keep only conversations that contain at least one user turn.
        if parsed_turns and any(t["Role"] == "user" for t in parsed_turns):
            result[int(conv_id)] = parsed_turns

    return result
|
||||
|
||||
|
||||
def load_t1_dataset(
    domains: Optional[List[str]] = None,
    split: str = "train",
    max_files_per_domain: Optional[int] = None,
    cache_dir: Optional[str] = None,
) -> List[List[dict]]:
    """Fetch T1 conversations from the HuggingFace dataset repo.

    Args:
        domains: Domains to include; defaults to the single-domain set.
        split: Dataset split — "train", "test", or "validation".
        max_files_per_domain: Cap on CSV files taken per domain (per the
            dataset layout each domain has 25 files, ~15 conversations each).
        cache_dir: Optional HuggingFace cache directory.

    Returns:
        A list of conversations; each conversation is a list of turn dicts.
    """
    if domains is None:
        domains = SINGLE_DOMAINS

    collected: List[List[dict]] = []

    # Enumerate every CSV in the repo once up front; the hasattr("size")
    # check keeps file entries and discards directory entries.
    tree = list(list_repo_tree(REPO_ID, repo_type="dataset", recursive=True))
    csv_paths = [
        entry.path
        for entry in tree
        if hasattr(entry, "size") and entry.path.endswith(".csv")
    ]

    for domain in domains:
        wanted_prefix = f"{domain}/{split}/"
        matching = sorted(p for p in csv_paths if p.startswith(wanted_prefix))
        if max_files_per_domain:
            matching = matching[:max_files_per_domain]

        logger.info(f"Loading {len(matching)} files from {domain}/{split}")

        for remote_path in matching:
            # Only forward cache_dir when the caller actually set one.
            download_kwargs = {"cache_dir": cache_dir} if cache_dir else {}
            local_path = hf_hub_download(
                REPO_ID, remote_path, repo_type="dataset", **download_kwargs
            )
            for turns in _csv_to_conversations(local_path).values():
                collected.append(turns)

    logger.info(f"Loaded {len(collected)} conversations total")
    return collected
|
||||
|
||||
|
||||
def load_t1_split(
    domains: Optional[List[str]] = None,
    max_files_per_domain: Optional[int] = None,
    eval_ratio: float = 0.1,
    seed: int = 42,
) -> Tuple[List[List[dict]], List[List[dict]]]:
    """Load the T1 "train" split and carve it into train/eval subsets.

    Args:
        domains: Domains to load (passed through to load_t1_dataset).
        max_files_per_domain: Cap on CSV files per domain.
        eval_ratio: Fraction of conversations reserved for eval
            (at least one conversation is always reserved).
        seed: Seed for the deterministic shuffle before splitting.

    Returns:
        (train_conversations, eval_conversations)
    """
    all_convos = load_t1_dataset(
        domains=domains,
        split="train",
        max_files_per_domain=max_files_per_domain,
    )

    # Seeded shuffle so the same seed always yields the same split.
    random.Random(seed).shuffle(all_convos)

    eval_count = max(1, int(len(all_convos) * eval_ratio))
    eval_part, train_part = all_convos[:eval_count], all_convos[eval_count:]

    logger.info(f"Split: {len(train_part)} train, {len(eval_part)} eval")
    return train_part, eval_part
|
||||
Loading…
Add table
Add a link
Reference in a new issue