add automatic dataset load

This commit is contained in:
Allan Niemerg 2025-05-27 11:57:17 -05:00
parent 013090579d
commit 0d54b3e83e
2 changed files with 103 additions and 35 deletions

View file

@ -188,38 +188,99 @@ class SmolagentsEnv(BaseEnv):
# Load the GAIA dataset
try:
import datasets
import os
logger.info(f"Loading GAIA dataset from {self.config.dataset_path}")
self.dataset = datasets.load_dataset(
f"{self.config.dataset_path}/GAIA.py",
name="2023_all",
split=self.config.split,
# Check if dataset exists
dataset_path = self.config.dataset_path
validation_path = os.path.join(
dataset_path, "2023", "validation", "metadata.jsonl"
)
gaia_py_path = os.path.join(dataset_path, "GAIA.py")
# If dataset files are missing, try to download them
if not os.path.exists(validation_path) or not os.path.exists(gaia_py_path):
logger.info(
f"GAIA dataset not found at {dataset_path}, attempting to download..."
)
from .download_gaia import download_gaia_dataset
download_success = download_gaia_dataset(dataset_path)
if not download_success:
logger.error("Failed to download GAIA dataset automatically.")
logger.error(
"Please run: python -m environments.smolagents_integration.download_gaia"
)
self.examples = []
return
else:
logger.info(
f"GAIA dataset downloaded successfully to {dataset_path}"
)
logger.info(
f"Loading GAIA dataset directly from {self.config.dataset_path}"
)
# Convert to standard format
self.examples = [
{
"question": example["Question"],
"true_answer": example["Final answer"],
"task": example["Level"],
"task_id": (
example["task_id"] if "task_id" in example else f"task_{i}"
),
"file_name": (
f"{self.config.dataset_path}/{self.config.split}/{example['file_name']}"
if example.get("file_name")
else ""
),
}
for i, example in enumerate(self.dataset)
]
# Load the metadata.jsonl file directly instead of using the datasets library
import json
metadata_path = os.path.join(
self.config.dataset_path, "2023", self.config.split, "metadata.jsonl"
)
logger.info(f"Reading metadata from: {metadata_path}")
# Check if the file exists
if not os.path.exists(metadata_path):
logger.error(f"Metadata file not found: {metadata_path}")
self.examples = []
return
# Read the metadata file directly
self.examples = []
with open(metadata_path, "r") as f:
for i, line in enumerate(f):
try:
example = json.loads(line)
self.examples.append(
{
"question": example["Question"],
"true_answer": example["Final answer"],
"task": example["Level"],
"task_id": (
example["task_id"]
if "task_id" in example
else f"task_{i}"
),
"file_name": (
os.path.join(
self.config.dataset_path,
"2023",
self.config.split,
example["file_name"],
)
if example.get("file_name")
else ""
),
}
)
except Exception as parse_error:
logger.error(
f"Error parsing line {i} of metadata file: {parse_error}"
)
continue
logger.info(
f"Loaded {len(self.examples)} examples from GAIA {self.config.split} set"
)
except Exception as e:
logger.error(f"Error loading GAIA dataset: {e}")
import traceback
logger.error(f"Error loading GAIA dataset: {type(e).__name__}: {e}")
logger.error(f"Detailed traceback: {traceback.format_exc()}")
logger.error(
"Please run: python -m environments.smolagents_integration.download_gaia"
)
# Create empty list if dataset loading fails
self.examples = []