Add BLEUBERI environment for reference-based RL

This commit is contained in:
Allan Niemerg 2025-06-08 18:02:33 -05:00
parent 3f6015e622
commit 5bb5bd2c3d
7 changed files with 948 additions and 0 deletions

View file

@@ -0,0 +1,295 @@
"""
Dataset utilities for the BLEUBERI environment.
"""
import logging
import random
from typing import Any, Dict, List, Optional
import pandas as pd
from datasets import Dataset, DatasetDict, load_dataset
logger = logging.getLogger(__name__)
def load_tulu_dataset(
    dataset_name: str = "allenai/tulu-3-sft-mixture",
    dataset_split: str = "train",
    cache_dir: Optional[str] = None,
    streaming: bool = False,
    shuffle: bool = True,
    seed: int = 42,
) -> Dataset:
    """
    Load the Tulu dataset from Hugging Face.

    Args:
        dataset_name: Hub identifier ("org/name") or a local dataset path
        dataset_split: Dataset split to load (train, validation, test)
        cache_dir: Directory to cache the dataset
        streaming: Whether to stream the dataset
        shuffle: Whether to shuffle the dataset (ignored when streaming)
        seed: Random seed for shuffling

    Returns:
        Loaded dataset. On any loading error, a small in-memory dummy
        dataset is returned instead so the environment stays usable offline.
    """
    try:
        # Any name containing "/" is a Hub-style identifier; the previous
        # extra startswith("allenai/") check was redundant with this test.
        if "/" in dataset_name:
            ds = load_dataset(
                dataset_name,
                split=dataset_split,
                cache_dir=cache_dir,
                streaming=streaming,
            )
        else:
            # Assume it's a local path
            loaded_ds = load_dataset(dataset_name, cache_dir=cache_dir)
            if isinstance(loaded_ds, DatasetDict):
                ds = loaded_ds[dataset_split]
            else:
                ds = loaded_ds
        if streaming:
            # A streaming IterableDataset has no len(); calling len(ds) here
            # used to raise TypeError, which the except below swallowed,
            # silently replacing every streamed dataset with dummy data.
            logger.info("Loaded dataset in streaming mode")
        else:
            logger.info(f"Loaded dataset with {len(ds)} examples")
            if shuffle:
                ds = ds.shuffle(seed=seed)
        return ds
    except Exception as e:
        logger.error(f"Error loading dataset: {e}")
        # Deliberate best-effort fallback: create a small dummy dataset so
        # tests and offline runs still have something to work with.
        dummy_data = [
            {
                "id": i,
                "messages": [
                    {"role": "user", "content": f"Sample prompt {i}"},
                    {"role": "assistant", "content": f"Sample response {i}"},
                ],
                "source": "dummy",
            }
            for i in range(10)
        ]
        return Dataset.from_list(dummy_data)
def get_user_prompt_from_messages(
    messages: List[Dict[str, str]], example_id: Any = None
) -> Optional[str]:
    """Extract the first user message's content from a chat transcript.

    Args:
        messages: Chat messages, each a dict with "role" and "content" keys.
        example_id: Optional identifier used only to label log messages.

    Returns:
        Content of the first message with role "user", or None if the list
        is empty/None or contains no user message.
    """
    if not messages:
        # Compare against None explicitly: ids of 0 or "" are legitimate
        # (the dummy dataset uses id=0) and should still be reported.
        if example_id is not None:
            logger.warning(f"Messages list is empty for example {example_id}.")
        return None
    for item in messages:
        if item.get("role") == "user":
            return item.get("content")
    if example_id is not None:
        logger.warning(f"No user prompt found in messages for example {example_id}.")
    return None
def get_assistant_response_from_messages(
    messages: List[Dict[str, str]], example_id: Any = None
) -> Optional[str]:
    """Extract the first assistant message's content from a chat transcript.

    Args:
        messages: Chat messages, each a dict with "role" and "content" keys.
        example_id: Optional identifier used only to label log messages.

    Returns:
        Content of the first message with role "assistant", or None if the
        list is empty/None or contains no assistant message.
    """
    if not messages:
        # Compare against None explicitly: ids of 0 or "" are legitimate
        # (the dummy dataset uses id=0) and should still be reported.
        if example_id is not None:
            logger.warning(f"Messages list is empty for example {example_id}.")
        return None
    for item in messages:
        if item.get("role") == "assistant":
            return item.get("content")
    if example_id is not None:
        logger.warning(
            f"No assistant response found in messages for example {example_id}."
        )
    return None
def aggregate_references(
    dataset: Dataset,
    ref_models: Optional[List[str]] = None,
) -> List[Dict[str, Any]]:
    """
    Aggregate references from the dataset based on specified reference models.

    Args:
        dataset: Input dataset
        ref_models: List of reference model names (or "gold" for ground
            truth). Defaults to ["gold"].

    Returns:
        List of dictionaries with references aggregated

    Raises:
        ValueError: If no requested model maps to a dataset column and
            "gold" was not requested either.
    """
    # None sentinel instead of a mutable default argument (["gold"]), so no
    # list object is shared across calls; behavior is unchanged for callers.
    if ref_models is None:
        ref_models = ["gold"]
    logger.info(f"Aggregating data from reference models: {ref_models}")
    # Check if "gold" (ground truth) is included in ref_models
    use_gold = "gold" in ref_models
    models_to_use = ref_models.copy()
    if use_gold:
        models_to_use.remove("gold")

    def _missing(value: Any) -> bool:
        # A reference counts as missing when it is None, NaN, or blank.
        return value is None or pd.isna(value) or str(value).strip() == ""

    # Create mapping from model names to column names
    model_column_mapping = {}
    available_ref_columns = [
        col for col in dataset.column_names if col.startswith("ref_output_")
    ]
    logger.info(f"Available reference output columns: {available_ref_columns}")
    for model in models_to_use:
        expected_column = f"ref_output_{model}"
        if expected_column in available_ref_columns:
            model_column_mapping[model] = expected_column
            logger.info(f"Mapped model '{model}' to column '{expected_column}'")
        else:
            logger.warning(
                f"Could not find column '{expected_column}' for model '{model}'"
            )
    # Only keep models that have corresponding columns
    models_to_use = [model for model in models_to_use if model in model_column_mapping]
    if not models_to_use and not use_gold:
        raise ValueError("No reference models could be mapped to dataset columns")

    # Keep only examples where every requested reference is present.
    def has_all_references(example):
        if use_gold and _missing(example.get("ref_output_gold")):
            return False
        return not any(
            _missing(example.get(model_column_mapping[model]))
            for model in models_to_use
        )

    dataset_filtered = dataset.filter(has_all_references)
    logger.info(
        f"After filtering for complete references: {len(dataset_filtered)} examples"
    )
    # Aggregate data
    aggregated_data = []
    for example in dataset_filtered:
        example_id = example.get("id", "unknown_id")
        # Prefer an explicit "prompt" field; fall back to the chat messages.
        if example.get("prompt") is not None:
            prompt = example["prompt"]
        else:
            prompt = get_user_prompt_from_messages(example.get("messages"), example_id)
        # Prefer ref_output_gold as ground truth; fall back to the messages.
        if example.get("ref_output_gold") is not None:
            ground_truth = example["ref_output_gold"]
        else:
            ground_truth = get_assistant_response_from_messages(
                example.get("messages"), example_id
            )
        # Collect references: gold first (when requested), then each model
        # in the caller-specified order.
        references = []
        if use_gold:
            references.append(ground_truth)
        for model in models_to_use:
            references.append(example[model_column_mapping[model]])
        aggregated_data.append(
            {
                "id": example_id,
                "source": example.get("source", "unknown"),
                "messages": example.get("messages", []),
                "prompt": prompt,
                "ground_truth": ground_truth,
                "references": references,
            }
        )
    logger.info(f"Aggregated {len(aggregated_data)} examples with specified references")
    return aggregated_data
def select_examples(
    data: List[Dict[str, Any]],
    selection_mode: str = "random",
    num_examples: Optional[int] = None,
    score_field: Optional[str] = None,
    seed: int = 42,
) -> List[Dict[str, Any]]:
    """
    Select examples based on the specified mode.

    Args:
        data: List of examples (with scores when using a scored mode)
        selection_mode: Mode for selection (random, easy, medium, hard)
        num_examples: Number of examples to select (if None, selects all)
        score_field: Field name for scores (needed for easy, medium, hard modes)
        seed: Random seed for selection

    Returns:
        List of selected examples
    """
    if not data:
        return []
    if selection_mode == "random":
        # `is not None` so that num_examples=0 selects zero examples instead
        # of silently returning the whole list (only None means "all").
        if num_examples is not None and len(data) > num_examples:
            # A local Random instance yields the exact same sample as
            # random.seed(seed); random.sample(...), without clobbering the
            # module-global RNG state for other callers.
            rng = random.Random(seed)
            indices = rng.sample(range(len(data)), num_examples)
            return [data[i] for i in indices]
        return data
    if selection_mode in ("easy", "medium", "hard"):
        if not score_field:
            logger.warning(
                f"Score field not provided for {selection_mode} selection mode. Using random selection."
            )
            return select_examples(data, "random", num_examples, None, seed)
        if not all(score_field in example for example in data):
            logger.warning(
                f"Not all examples have score field '{score_field}'. Using random selection."
            )
            return select_examples(data, "random", num_examples, None, seed)
        # "easy" sorts descending (highest scores first); "medium" and
        # "hard" sort ascending.
        sorted_data = sorted(
            data, key=lambda x: x[score_field], reverse=(selection_mode == "easy")
        )
        if num_examples is not None and len(sorted_data) > num_examples:
            if selection_mode == "medium":
                # Take the slice centered in the score distribution.
                start_idx = (len(sorted_data) - num_examples) // 2
                return sorted_data[start_idx : start_idx + num_examples]
            # easy/hard both take the head; the sort direction differs.
            return sorted_data[:num_examples]
        return sorted_data
    logger.warning(
        f"Unknown selection mode: {selection_mode}. Using random selection."
    )
    return select_examples(data, "random", num_examples, None, seed)