mirror of
https://github.com/NousResearch/atropos.git
synced 2026-05-01 17:45:16 +00:00
Add BLEUBERI environment for reference-based RL
This commit is contained in:
parent
3f6015e622
commit
5bb5bd2c3d
7 changed files with 948 additions and 0 deletions
295
environments/bleuberi/dataset_utils.py
Normal file
295
environments/bleuberi/dataset_utils.py
Normal file
|
|
@ -0,0 +1,295 @@
|
|||
"""
|
||||
Dataset utilities for the BLEUBERI environment.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import random
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import pandas as pd
|
||||
from datasets import Dataset, DatasetDict, load_dataset
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def load_tulu_dataset(
|
||||
dataset_name: str = "allenai/tulu-3-sft-mixture",
|
||||
dataset_split: str = "train",
|
||||
cache_dir: Optional[str] = None,
|
||||
streaming: bool = False,
|
||||
shuffle: bool = True,
|
||||
seed: int = 42,
|
||||
) -> Dataset:
|
||||
"""
|
||||
Load the Tulu dataset from Hugging Face.
|
||||
|
||||
Args:
|
||||
dataset_name: Name of the dataset on Hugging Face
|
||||
dataset_split: Dataset split to load (train, validation, test)
|
||||
cache_dir: Directory to cache the dataset
|
||||
streaming: Whether to stream the dataset
|
||||
shuffle: Whether to shuffle the dataset
|
||||
seed: Random seed for shuffling
|
||||
|
||||
Returns:
|
||||
Loaded dataset
|
||||
"""
|
||||
try:
|
||||
if dataset_name.startswith("allenai/") or "/" in dataset_name:
|
||||
ds = load_dataset(
|
||||
dataset_name,
|
||||
split=dataset_split,
|
||||
cache_dir=cache_dir,
|
||||
streaming=streaming,
|
||||
)
|
||||
else:
|
||||
# Assume it's a local path
|
||||
loaded_ds = load_dataset(dataset_name, cache_dir=cache_dir)
|
||||
if isinstance(loaded_ds, DatasetDict):
|
||||
ds = loaded_ds[dataset_split]
|
||||
else:
|
||||
ds = loaded_ds
|
||||
|
||||
logger.info(f"Loaded dataset with {len(ds)} examples")
|
||||
|
||||
if shuffle and not streaming:
|
||||
ds = ds.shuffle(seed=seed)
|
||||
|
||||
return ds
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error loading dataset: {e}")
|
||||
# Create a small dummy dataset for testing
|
||||
dummy_data = []
|
||||
for i in range(10):
|
||||
dummy_data.append(
|
||||
{
|
||||
"id": i,
|
||||
"messages": [
|
||||
{"role": "user", "content": f"Sample prompt {i}"},
|
||||
{"role": "assistant", "content": f"Sample response {i}"},
|
||||
],
|
||||
"source": "dummy",
|
||||
}
|
||||
)
|
||||
return Dataset.from_list(dummy_data)
|
||||
|
||||
|
||||
def get_user_prompt_from_messages(
|
||||
messages: List[Dict[str, str]], example_id: Any = None
|
||||
) -> Optional[str]:
|
||||
"""Extract the user prompt from a list of messages."""
|
||||
if not messages:
|
||||
if example_id:
|
||||
logger.warning(f"Messages list is empty for example {example_id}.")
|
||||
return None
|
||||
|
||||
for item in messages:
|
||||
if item.get("role") == "user":
|
||||
return item.get("content")
|
||||
|
||||
if example_id:
|
||||
logger.warning(f"No user prompt found in messages for example {example_id}.")
|
||||
return None
|
||||
|
||||
|
||||
def get_assistant_response_from_messages(
|
||||
messages: List[Dict[str, str]], example_id: Any = None
|
||||
) -> Optional[str]:
|
||||
"""Extract the assistant response from a list of messages."""
|
||||
if not messages:
|
||||
if example_id:
|
||||
logger.warning(f"Messages list is empty for example {example_id}.")
|
||||
return None
|
||||
|
||||
for item in messages:
|
||||
if item.get("role") == "assistant":
|
||||
return item.get("content")
|
||||
|
||||
if example_id:
|
||||
logger.warning(
|
||||
f"No assistant response found in messages for example {example_id}."
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
def aggregate_references(
|
||||
dataset: Dataset,
|
||||
ref_models: List[str] = ["gold"],
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Aggregate references from the dataset based on specified reference models.
|
||||
|
||||
Args:
|
||||
dataset: Input dataset
|
||||
ref_models: List of reference model names (or "gold" for ground truth)
|
||||
|
||||
Returns:
|
||||
List of dictionaries with references aggregated
|
||||
"""
|
||||
logger.info(f"Aggregating data from reference models: {ref_models}")
|
||||
|
||||
# Check if "gold" (ground truth) is included in ref_models
|
||||
use_gold = "gold" in ref_models
|
||||
models_to_use = ref_models.copy()
|
||||
if use_gold:
|
||||
models_to_use.remove("gold")
|
||||
|
||||
# Create mapping from model names to column names
|
||||
model_column_mapping = {}
|
||||
available_ref_columns = [
|
||||
col for col in dataset.column_names if col.startswith("ref_output_")
|
||||
]
|
||||
|
||||
logger.info(f"Available reference output columns: {available_ref_columns}")
|
||||
|
||||
for model in models_to_use:
|
||||
expected_column = f"ref_output_{model}"
|
||||
|
||||
if expected_column in available_ref_columns:
|
||||
model_column_mapping[model] = expected_column
|
||||
logger.info(f"Mapped model '{model}' to column '{expected_column}'")
|
||||
else:
|
||||
logger.warning(
|
||||
f"Could not find column '{expected_column}' for model '{model}'"
|
||||
)
|
||||
|
||||
# Only keep models that have corresponding columns
|
||||
models_to_use = [model for model in models_to_use if model in model_column_mapping]
|
||||
if not models_to_use and not use_gold:
|
||||
raise ValueError("No reference models could be mapped to dataset columns")
|
||||
|
||||
# Filter examples that have all references
|
||||
def has_all_references(example):
|
||||
if use_gold and (
|
||||
example.get("ref_output_gold") is None
|
||||
or pd.isna(example.get("ref_output_gold"))
|
||||
or str(example.get("ref_output_gold")).strip() == ""
|
||||
):
|
||||
return False
|
||||
|
||||
for model in models_to_use:
|
||||
col_name = model_column_mapping[model]
|
||||
if (
|
||||
example.get(col_name) is None
|
||||
or pd.isna(example.get(col_name))
|
||||
or str(example.get(col_name)).strip() == ""
|
||||
):
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
dataset_filtered = dataset.filter(has_all_references)
|
||||
logger.info(
|
||||
f"After filtering for complete references: {len(dataset_filtered)} examples"
|
||||
)
|
||||
|
||||
# Aggregate data
|
||||
aggregated_data = []
|
||||
for example in dataset_filtered:
|
||||
example_id = example.get("id", "unknown_id")
|
||||
|
||||
# Get prompt from "prompt" field or messages
|
||||
if "prompt" in example and example["prompt"] is not None:
|
||||
prompt = example["prompt"]
|
||||
else:
|
||||
prompt = get_user_prompt_from_messages(example.get("messages"), example_id)
|
||||
|
||||
# Get ground truth from ref_output_gold or messages
|
||||
if "ref_output_gold" in example and example["ref_output_gold"] is not None:
|
||||
ground_truth = example["ref_output_gold"]
|
||||
else:
|
||||
ground_truth = get_assistant_response_from_messages(
|
||||
example.get("messages"), example_id
|
||||
)
|
||||
|
||||
# Collect references
|
||||
references = []
|
||||
if use_gold:
|
||||
references.append(ground_truth)
|
||||
|
||||
for model in models_to_use:
|
||||
col_name = model_column_mapping[model]
|
||||
references.append(example[col_name])
|
||||
|
||||
# Create aggregated example
|
||||
aggregated_example = {
|
||||
"id": example_id,
|
||||
"source": example.get("source", "unknown"),
|
||||
"messages": example.get("messages", []),
|
||||
"prompt": prompt,
|
||||
"ground_truth": ground_truth,
|
||||
"references": references,
|
||||
}
|
||||
|
||||
aggregated_data.append(aggregated_example)
|
||||
|
||||
logger.info(f"Aggregated {len(aggregated_data)} examples with specified references")
|
||||
return aggregated_data
|
||||
|
||||
|
||||
def select_examples(
|
||||
data: List[Dict[str, Any]],
|
||||
selection_mode: str = "random",
|
||||
num_examples: Optional[int] = None,
|
||||
score_field: Optional[str] = None,
|
||||
seed: int = 42,
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Select examples based on the specified mode.
|
||||
|
||||
Args:
|
||||
data: List of examples with scores
|
||||
selection_mode: Mode for selection (random, easy, medium, hard)
|
||||
num_examples: Number of examples to select (if None, selects all)
|
||||
score_field: Field name for scores (needed for easy, medium, hard modes)
|
||||
seed: Random seed for selection
|
||||
|
||||
Returns:
|
||||
List of selected examples
|
||||
"""
|
||||
if not data:
|
||||
return []
|
||||
|
||||
if selection_mode == "random":
|
||||
if num_examples and len(data) > num_examples:
|
||||
random.seed(seed)
|
||||
indices = random.sample(range(len(data)), num_examples)
|
||||
return [data[i] for i in indices]
|
||||
else:
|
||||
return data
|
||||
|
||||
elif selection_mode in ["easy", "medium", "hard"]:
|
||||
if not score_field:
|
||||
logger.warning(
|
||||
f"Score field not provided for {selection_mode} selection mode. Using random selection."
|
||||
)
|
||||
return select_examples(data, "random", num_examples, None, seed)
|
||||
|
||||
# Sort based on scores
|
||||
if all(score_field in example for example in data):
|
||||
sorted_data = sorted(
|
||||
data, key=lambda x: x[score_field], reverse=(selection_mode == "easy")
|
||||
)
|
||||
|
||||
if num_examples and len(sorted_data) > num_examples:
|
||||
if selection_mode == "medium":
|
||||
# Select from the middle
|
||||
start_idx = (len(sorted_data) - num_examples) // 2
|
||||
return sorted_data[start_idx : start_idx + num_examples]
|
||||
else:
|
||||
# Select from the beginning (for both easy and hard, just sorted differently)
|
||||
return sorted_data[:num_examples]
|
||||
else:
|
||||
return sorted_data
|
||||
else:
|
||||
logger.warning(
|
||||
f"Not all examples have score field '{score_field}'. Using random selection."
|
||||
)
|
||||
return select_examples(data, "random", num_examples, None, seed)
|
||||
|
||||
else:
|
||||
logger.warning(
|
||||
f"Unknown selection mode: {selection_mode}. Using random selection."
|
||||
)
|
||||
return select_examples(data, "random", num_examples, None, seed)
|
||||
Loading…
Add table
Add a link
Reference in a new issue