mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-19 12:57:58 +00:00
add automatic dataset load
This commit is contained in:
parent
013090579d
commit
0d54b3e83e
2 changed files with 103 additions and 35 deletions
|
|
@ -62,10 +62,10 @@ def download_gaia_dataset(output_dir, use_raw=False):
|
|||
ignore_patterns=[".gitattributes", "README.md"],
|
||||
)
|
||||
|
||||
# Create a minimal GAIA.py that loads directly from metadata.jsonl
|
||||
# Create a minimal GAIA.py that loads directly from metadata.jsonl using absolute paths
|
||||
with open(os.path.join(output_dir, "GAIA.py"), "w") as f:
|
||||
f.write(
|
||||
'''
|
||||
f'''
|
||||
"""
|
||||
GAIA benchmark dataset loader.
|
||||
Loads data directly from metadata.jsonl files.
|
||||
|
|
@ -75,6 +75,9 @@ import os
|
|||
import json
|
||||
import datasets
|
||||
|
||||
# Define absolute path to the dataset directory - crucial for correct operation
|
||||
DATASET_PATH = "{os.path.abspath(output_dir)}"
|
||||
|
||||
class GAIA(datasets.GeneratorBasedBuilder):
|
||||
VERSION = datasets.Version("2023.0.0")
|
||||
BUILDER_CONFIGS = [
|
||||
|
|
@ -88,13 +91,13 @@ class GAIA(datasets.GeneratorBasedBuilder):
|
|||
def _info(self):
|
||||
return datasets.DatasetInfo(
|
||||
description="GAIA benchmark dataset",
|
||||
features=datasets.Features({
|
||||
features=datasets.Features({{
|
||||
"Question": datasets.Value("string"),
|
||||
"Final answer": datasets.Value("string"),
|
||||
"Level": datasets.Value("string"),
|
||||
"task_id": datasets.Value("string"),
|
||||
"file_name": datasets.Value("string"),
|
||||
}),
|
||||
}}),
|
||||
homepage="https://huggingface.co/datasets/gaia-benchmark/GAIA",
|
||||
)
|
||||
|
||||
|
|
@ -102,30 +105,34 @@ class GAIA(datasets.GeneratorBasedBuilder):
|
|||
return [
|
||||
datasets.SplitGenerator(
|
||||
name=datasets.Split.VALIDATION,
|
||||
gen_kwargs={"split": "validation"},
|
||||
gen_kwargs={{"split": "validation"}},
|
||||
),
|
||||
datasets.SplitGenerator(
|
||||
name=datasets.Split.TEST,
|
||||
gen_kwargs={"split": "test"},
|
||||
gen_kwargs={{"split": "test"}},
|
||||
),
|
||||
]
|
||||
|
||||
def _generate_examples(self, split):
|
||||
"""Read data from the metadata.jsonl file."""
|
||||
# Data is stored in a 2023 subfolder
|
||||
metadata_path = os.path.join(os.path.dirname(__file__), "2023", split, "metadata.jsonl")
|
||||
# Use absolute path to the metadata.jsonl file
|
||||
metadata_path = os.path.join(DATASET_PATH, "2023", split, "metadata.jsonl")
|
||||
print(f"Loading GAIA data from: {{metadata_path}}")
|
||||
|
||||
with open(metadata_path, "r") as f:
|
||||
for i, line in enumerate(f):
|
||||
example = json.loads(line)
|
||||
if "file_name" in example and example["file_name"]:
|
||||
# Ensure file paths include the 2023 directory
|
||||
# Ensure file paths include the 2023 directory and absolute path
|
||||
example["file_name"] = os.path.join("2023", split, example["file_name"])
|
||||
yield i, example
|
||||
'''
|
||||
)
|
||||
|
||||
# Verify the download worked by checking for key files
|
||||
validation_path = os.path.join(output_dir, "2023", "validation", "metadata.jsonl")
|
||||
validation_path = os.path.join(
|
||||
output_dir, "2023", "validation", "metadata.jsonl"
|
||||
)
|
||||
if not os.path.exists(validation_path):
|
||||
logger.error(f"Download appears incomplete. Missing: {validation_path}")
|
||||
return False
|
||||
|
|
@ -145,7 +152,7 @@ def main():
|
|||
if success:
|
||||
logger.info("GAIA dataset setup completed successfully")
|
||||
logger.info(
|
||||
f"You can now run: python -m environments.smolagents_integration.run_gaia_single_task"
|
||||
"You can now run: python -m environments.smolagents_integration.run_gaia_single_task"
|
||||
)
|
||||
else:
|
||||
logger.error("GAIA dataset setup failed - Aborting")
|
||||
|
|
|
|||
|
|
@ -188,38 +188,99 @@ class SmolagentsEnv(BaseEnv):
|
|||
|
||||
# Load the GAIA dataset
|
||||
try:
|
||||
import datasets
|
||||
import os
|
||||
|
||||
logger.info(f"Loading GAIA dataset from {self.config.dataset_path}")
|
||||
self.dataset = datasets.load_dataset(
|
||||
f"{self.config.dataset_path}/GAIA.py",
|
||||
name="2023_all",
|
||||
split=self.config.split,
|
||||
# Check if dataset exists
|
||||
dataset_path = self.config.dataset_path
|
||||
validation_path = os.path.join(
|
||||
dataset_path, "2023", "validation", "metadata.jsonl"
|
||||
)
|
||||
gaia_py_path = os.path.join(dataset_path, "GAIA.py")
|
||||
|
||||
# If dataset files are missing, try to download them
|
||||
if not os.path.exists(validation_path) or not os.path.exists(gaia_py_path):
|
||||
logger.info(
|
||||
f"GAIA dataset not found at {dataset_path}, attempting to download..."
|
||||
)
|
||||
from .download_gaia import download_gaia_dataset
|
||||
|
||||
download_success = download_gaia_dataset(dataset_path)
|
||||
if not download_success:
|
||||
logger.error("Failed to download GAIA dataset automatically.")
|
||||
logger.error(
|
||||
"Please run: python -m environments.smolagents_integration.download_gaia"
|
||||
)
|
||||
self.examples = []
|
||||
return
|
||||
else:
|
||||
logger.info(
|
||||
f"GAIA dataset downloaded successfully to {dataset_path}"
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Loading GAIA dataset directly from {self.config.dataset_path}"
|
||||
)
|
||||
|
||||
# Convert to standard format
|
||||
self.examples = [
|
||||
{
|
||||
"question": example["Question"],
|
||||
"true_answer": example["Final answer"],
|
||||
"task": example["Level"],
|
||||
"task_id": (
|
||||
example["task_id"] if "task_id" in example else f"task_{i}"
|
||||
),
|
||||
"file_name": (
|
||||
f"{self.config.dataset_path}/{self.config.split}/{example['file_name']}"
|
||||
if example.get("file_name")
|
||||
else ""
|
||||
),
|
||||
}
|
||||
for i, example in enumerate(self.dataset)
|
||||
]
|
||||
# Load the metadata.jsonl file directly instead of using the datasets library
|
||||
import json
|
||||
|
||||
metadata_path = os.path.join(
|
||||
self.config.dataset_path, "2023", self.config.split, "metadata.jsonl"
|
||||
)
|
||||
|
||||
logger.info(f"Reading metadata from: {metadata_path}")
|
||||
|
||||
# Check if the file exists
|
||||
if not os.path.exists(metadata_path):
|
||||
logger.error(f"Metadata file not found: {metadata_path}")
|
||||
self.examples = []
|
||||
return
|
||||
|
||||
# Read the metadata file directly
|
||||
self.examples = []
|
||||
with open(metadata_path, "r") as f:
|
||||
for i, line in enumerate(f):
|
||||
try:
|
||||
example = json.loads(line)
|
||||
self.examples.append(
|
||||
{
|
||||
"question": example["Question"],
|
||||
"true_answer": example["Final answer"],
|
||||
"task": example["Level"],
|
||||
"task_id": (
|
||||
example["task_id"]
|
||||
if "task_id" in example
|
||||
else f"task_{i}"
|
||||
),
|
||||
"file_name": (
|
||||
os.path.join(
|
||||
self.config.dataset_path,
|
||||
"2023",
|
||||
self.config.split,
|
||||
example["file_name"],
|
||||
)
|
||||
if example.get("file_name")
|
||||
else ""
|
||||
),
|
||||
}
|
||||
)
|
||||
except Exception as parse_error:
|
||||
logger.error(
|
||||
f"Error parsing line {i} of metadata file: {parse_error}"
|
||||
)
|
||||
continue
|
||||
|
||||
logger.info(
|
||||
f"Loaded {len(self.examples)} examples from GAIA {self.config.split} set"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error loading GAIA dataset: {e}")
|
||||
import traceback
|
||||
|
||||
logger.error(f"Error loading GAIA dataset: {type(e).__name__}: {e}")
|
||||
logger.error(f"Detailed traceback: {traceback.format_exc()}")
|
||||
logger.error(
|
||||
"Please run: python -m environments.smolagents_integration.download_gaia"
|
||||
)
|
||||
# Create empty list if dataset loading fails
|
||||
self.examples = []
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue