mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-24 17:04:55 +00:00
Integrate BLEUBERI as a submodule with direct import of reference-based reward functions.
This commit is contained in:
parent
5bb5bd2c3d
commit
a520f5f663
4 changed files with 270 additions and 70 deletions
|
|
@ -7,8 +7,9 @@ using BLEU scores as rewards. Based on the paper:
|
|||
https://arxiv.org/abs/2505.11080
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import random
|
||||
import sys
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
|
|
@ -17,14 +18,15 @@ from pydantic import Field
|
|||
import wandb
|
||||
from atroposlib.envs.base import BaseEnv, BaseEnvConfig, Item
|
||||
|
||||
from .dataset_utils import (
|
||||
aggregate_references,
|
||||
load_tulu_dataset,
|
||||
select_examples,
|
||||
# Add the BLEUBERI repository to the Python path
|
||||
_SUBMODULE_DIR = os.path.abspath(
|
||||
os.path.join(os.path.dirname(__file__), "bleuberi-repo")
|
||||
)
|
||||
from .reward_functions import REWARD_FUNCTIONS
|
||||
if _SUBMODULE_DIR not in sys.path:
|
||||
sys.path.insert(0, _SUBMODULE_DIR)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
# Import components directly from the BLEUBERI repository
|
||||
from training.dataset import KeywordDataset # noqa: E402
|
||||
|
||||
|
||||
class BLEUBERIEnvConfig(BaseEnvConfig):
|
||||
|
|
@ -39,6 +41,18 @@ class BLEUBERIEnvConfig(BaseEnvConfig):
|
|||
default="train",
|
||||
description="Dataset split to use",
|
||||
)
|
||||
cache_dir: Optional[str] = Field(
|
||||
default=None,
|
||||
description="Cache directory for datasets and models",
|
||||
)
|
||||
streaming: bool = Field(
|
||||
default=False,
|
||||
description="Whether to stream the dataset",
|
||||
)
|
||||
shuffle: bool = Field(
|
||||
default=True,
|
||||
description="Whether to shuffle the dataset",
|
||||
)
|
||||
|
||||
# Reference model configuration
|
||||
ref_models: List[str] = Field(
|
||||
|
|
@ -68,11 +82,18 @@ class BLEUBERIEnvConfig(BaseEnvConfig):
|
|||
"You are a deep thinking AI, you may use extremely long chains of thought to deeply consider "
|
||||
"the problem and deliberate with yourself via systematic reasoning processes to help come to "
|
||||
"a correct solution prior to answering. You should enclose your thoughts and internal monologue "
|
||||
"inside <think> </think> tags, and then provide your solution or response to the problem."
|
||||
"inside <think> </think> tags, and then provide your solution or response to the problem. "
|
||||
"After your thinking, make sure to clearly provide your final answer inside <answer></answer> tags."
|
||||
),
|
||||
description="System prompt for the model",
|
||||
)
|
||||
|
||||
# Reasoning
|
||||
reasoning: bool = Field(
|
||||
default=True,
|
||||
description="Whether to enable reasoning in the system prompt",
|
||||
)
|
||||
|
||||
# Random seed
|
||||
seed: int = Field(
|
||||
default=42,
|
||||
|
|
@ -115,17 +136,65 @@ class BLEUBERIEnv(BaseEnv):
|
|||
random.seed(config.seed)
|
||||
np.random.seed(config.seed)
|
||||
|
||||
# Initialize reward functions and metrics
|
||||
self._init_metrics()
|
||||
|
||||
def _init_metrics(self):
    """Create the BLEUBERI dataset helper and expose its metric objects.

    Sets ``self.logger``, ``self.bleuberi_dataset`` and the ``bleu``,
    ``rouge`` and ``bertscore`` attributes copied from that dataset.
    """
    # Function-scope import, kept as in the original implementation.
    import logging

    self.logger = logging.getLogger(self.__class__.__name__)

    # An empty path is passed on purpose: per the original author's note,
    # only the reward functions of the KeywordDataset are used here.
    reward_dataset = KeywordDataset("", self.tokenizer)
    self.bleuberi_dataset = reward_dataset

    # Mirror the dataset's metric objects onto the environment, in the
    # same order the attributes were originally assigned.
    for metric_name in ("bleu", "rouge", "bertscore"):
        setattr(self, metric_name, getattr(reward_dataset, metric_name))

    self.logger.info("BLEUBERI reward metrics initialized")
|
||||
|
||||
async def setup(self):
|
||||
"""Set up the environment, loading datasets and preparing examples."""
|
||||
logger.info("Setting up BLEUBERI environment")
|
||||
self.logger.info("Setting up BLEUBERI environment")
|
||||
|
||||
# Load dataset
|
||||
self.dataset = load_tulu_dataset(
|
||||
dataset_name=self.config.dataset_name,
|
||||
dataset_split=self.config.dataset_split,
|
||||
shuffle=True,
|
||||
seed=self.config.seed,
|
||||
)
|
||||
try:
|
||||
from datasets import load_dataset
|
||||
|
||||
self.dataset = load_dataset(
|
||||
self.config.dataset_name,
|
||||
split=self.config.dataset_split,
|
||||
cache_dir=self.config.cache_dir,
|
||||
streaming=self.config.streaming,
|
||||
)
|
||||
|
||||
if self.config.shuffle and not self.config.streaming:
|
||||
self.dataset = self.dataset.shuffle(seed=self.config.seed)
|
||||
|
||||
self.logger.info(f"Loaded dataset with {len(self.dataset)} examples")
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error loading dataset: {e}")
|
||||
# Create a small dummy dataset for testing
|
||||
from datasets import Dataset
|
||||
|
||||
dummy_data = []
|
||||
for i in range(10):
|
||||
dummy_data.append(
|
||||
{
|
||||
"id": i,
|
||||
"messages": [
|
||||
{"role": "user", "content": f"Sample prompt {i}"},
|
||||
{"role": "assistant", "content": f"Sample response {i}"},
|
||||
],
|
||||
"source": "dummy",
|
||||
}
|
||||
)
|
||||
self.dataset = Dataset.from_list(dummy_data)
|
||||
self.logger.info(f"Created dummy dataset with {len(self.dataset)} examples")
|
||||
|
||||
# Split into train and test (98% train, 2% test)
|
||||
train_size = int(0.98 * len(self.dataset))
|
||||
|
|
@ -133,35 +202,16 @@ class BLEUBERIEnv(BaseEnv):
|
|||
train_dataset = self.dataset.select(range(train_size))
|
||||
test_dataset = self.dataset.select(range(train_size, len(self.dataset)))
|
||||
|
||||
logger.info(
|
||||
self.logger.info(
|
||||
f"Split dataset into {len(train_dataset)} train and {len(test_dataset)} test examples"
|
||||
)
|
||||
|
||||
# Aggregate references
|
||||
self.train_aggregated = aggregate_references(
|
||||
train_dataset, self.config.ref_models
|
||||
)
|
||||
self.test_aggregated = aggregate_references(
|
||||
test_dataset, self.config.ref_models
|
||||
)
|
||||
self.train_examples = await self._aggregate_references(train_dataset)
|
||||
self.test_examples = await self._aggregate_references(test_dataset)
|
||||
|
||||
# Select examples based on selection mode
|
||||
self.train_examples = select_examples(
|
||||
self.train_aggregated,
|
||||
selection_mode=self.config.selection_mode,
|
||||
num_examples=self.config.num_examples,
|
||||
seed=self.config.seed,
|
||||
)
|
||||
|
||||
self.test_examples = select_examples(
|
||||
self.test_aggregated,
|
||||
selection_mode="random",
|
||||
num_examples=min(len(self.test_aggregated), 100), # Limit test set size
|
||||
seed=self.config.seed,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Selected {len(self.train_examples)} train and {len(self.test_examples)} test examples"
|
||||
self.logger.info(
|
||||
f"Prepared {len(self.train_examples)} train and {len(self.test_examples)} test examples"
|
||||
)
|
||||
|
||||
# Shuffle train examples
|
||||
|
|
@ -170,10 +220,56 @@ class BLEUBERIEnv(BaseEnv):
|
|||
|
||||
self.train_index = 0
|
||||
|
||||
async def _aggregate_references(self, dataset):
|
||||
"""
|
||||
Aggregate references from the dataset based on specified reference models.
|
||||
|
||||
This is an async wrapper around the BLEUBERI aggregation logic.
|
||||
"""
|
||||
# Process examples to extract prompts and references
|
||||
examples = []
|
||||
|
||||
for example in dataset:
|
||||
example_id = example.get("id", "unknown_id")
|
||||
|
||||
# Get prompt from messages
|
||||
prompt = None
|
||||
if "messages" in example and example["messages"]:
|
||||
for msg in example["messages"]:
|
||||
if msg.get("role") == "user":
|
||||
prompt = msg.get("content")
|
||||
break
|
||||
|
||||
# Get ground truth from messages
|
||||
ground_truth = None
|
||||
if "messages" in example and example["messages"]:
|
||||
for msg in example["messages"]:
|
||||
if msg.get("role") == "assistant":
|
||||
ground_truth = msg.get("content")
|
||||
break
|
||||
|
||||
# Skip examples without prompt or ground truth
|
||||
if not prompt or not ground_truth:
|
||||
continue
|
||||
|
||||
# Create example
|
||||
aggregated_example = {
|
||||
"id": example_id,
|
||||
"source": example.get("source", "unknown"),
|
||||
"messages": example.get("messages", []),
|
||||
"prompt": prompt,
|
||||
"ground_truth": ground_truth,
|
||||
"references": [ground_truth], # Using ground truth as reference
|
||||
}
|
||||
|
||||
examples.append(aggregated_example)
|
||||
|
||||
return examples
|
||||
|
||||
async def get_next_item(self) -> Item:
|
||||
"""Get the next example from the dataset."""
|
||||
if not self.train_examples:
|
||||
logger.warning("No train examples available")
|
||||
self.logger.warning("No train examples available")
|
||||
return None
|
||||
|
||||
# Cycle through the dataset
|
||||
|
|
@ -183,8 +279,8 @@ class BLEUBERIEnv(BaseEnv):
|
|||
# Format the prompt as a conversation
|
||||
messages = []
|
||||
|
||||
# Add system message if provided
|
||||
if self.config.system_prompt:
|
||||
# Add system message if provided and reasoning is enabled
|
||||
if self.config.system_prompt and self.config.reasoning:
|
||||
messages.append({"role": "system", "content": self.config.system_prompt})
|
||||
|
||||
# Add user prompt
|
||||
|
|
@ -207,6 +303,107 @@ class BLEUBERIEnv(BaseEnv):
|
|||
|
||||
return item
|
||||
|
||||
def _extract_answer(self, completion: str) -> str:
|
||||
"""Extract the answer from a completion with potential thinking tags."""
|
||||
# Use the extract_answer method from BLEUBERI's KeywordDataset
|
||||
return KeywordDataset.extract_answer(self, completion)
|
||||
|
||||
async def _calculate_bleu_score(
|
||||
self, response_content: str, references: List[str]
|
||||
) -> float:
|
||||
"""Calculate BLEU score for a response against references using BLEUBERI implementation."""
|
||||
# Create a mock dataset instance to access the reward functions
|
||||
dataset = KeywordDataset("", self.tokenizer)
|
||||
|
||||
# Prepare the inputs in the format expected by BLEUBERI
|
||||
completion = response_content
|
||||
kwargs = {"references": references}
|
||||
|
||||
# Use BLEUBERI's bleu_reward_func method
|
||||
scores = dataset.bleu_reward_func([completion], **kwargs)
|
||||
return scores[0] if scores else 0.0
|
||||
|
||||
async def _calculate_rouge_score(
|
||||
self, response_content: str, references: List[str]
|
||||
) -> float:
|
||||
"""Calculate ROUGE score for a response against references using BLEUBERI implementation."""
|
||||
# Create a mock dataset instance to access the reward functions
|
||||
dataset = KeywordDataset("", self.tokenizer)
|
||||
|
||||
# Prepare the inputs in the format expected by BLEUBERI
|
||||
completion = response_content
|
||||
kwargs = {"references": references}
|
||||
|
||||
# Use BLEUBERI's rouge_reward_func method
|
||||
scores = dataset.rouge_reward_func([completion], **kwargs)
|
||||
return scores[0] if scores else 0.0
|
||||
|
||||
async def _calculate_bertscore(
|
||||
self, response_content: str, references: List[str]
|
||||
) -> float:
|
||||
"""Calculate BERTScore for a response against references using BLEUBERI implementation."""
|
||||
# Create a mock dataset instance to access the reward functions
|
||||
dataset = KeywordDataset("", self.tokenizer)
|
||||
|
||||
# Prepare the inputs in the format expected by BLEUBERI
|
||||
completion = response_content
|
||||
kwargs = {"references": references}
|
||||
|
||||
# Use BLEUBERI's bertscore_reward_func method
|
||||
scores = dataset.bertscore_reward_func([completion], **kwargs)
|
||||
return scores[0] if scores else 0.0
|
||||
|
||||
async def _calculate_bleu_rouge_f1(
|
||||
self, response_content: str, references: List[str]
|
||||
) -> float:
|
||||
"""Calculate F1 of BLEU and ROUGE scores using BLEUBERI implementation."""
|
||||
# Create a mock dataset instance to access the reward functions
|
||||
dataset = KeywordDataset("", self.tokenizer)
|
||||
|
||||
# Prepare the inputs in the format expected by BLEUBERI
|
||||
completion = response_content
|
||||
kwargs = {"references": references}
|
||||
|
||||
# Use BLEUBERI's bleu_rouge_f1_reward_func method
|
||||
scores = dataset.bleu_rouge_f1_reward_func([completion], **kwargs)
|
||||
return scores[0] if scores else 0.0
|
||||
|
||||
async def _calculate_reward(
|
||||
self, response_content: str, references: List[str]
|
||||
) -> float:
|
||||
"""
|
||||
Calculate the reward for a response based on the configured reward functions.
|
||||
Uses BLEUBERI's reward functions directly.
|
||||
"""
|
||||
# Get the appropriate reward functions from BLEUBERI's implementation
|
||||
reward_funcs = self.bleuberi_dataset.get_reward_funcs(self.config.reward_funcs)
|
||||
|
||||
if not reward_funcs:
|
||||
self.logger.warning("No valid reward functions found")
|
||||
return 0.0
|
||||
|
||||
# Calculate scores using each reward function
|
||||
all_scores = []
|
||||
for reward_func in reward_funcs:
|
||||
# Apply the reward function
|
||||
kwargs = {"references": references}
|
||||
if (
|
||||
hasattr(reward_func, "__name__")
|
||||
and reward_func.__name__ == "rm_reward_func"
|
||||
):
|
||||
# RM reward function requires prompts
|
||||
kwargs["prompts"] = [
|
||||
[{"role": "user", "content": "prompt"}]
|
||||
] # dummy prompt
|
||||
|
||||
scores = reward_func([response_content], **kwargs)
|
||||
if scores and len(scores) > 0:
|
||||
all_scores.append(scores[0])
|
||||
|
||||
# Take the average of all scores
|
||||
final_score = sum(all_scores) / len(all_scores) if all_scores else 0.0
|
||||
return final_score
|
||||
|
||||
async def collect_trajectory(self, item: Item) -> Tuple[Dict, List[Item]]:
|
||||
"""Generate a response and score it against references."""
|
||||
backlog = []
|
||||
|
|
@ -221,18 +418,8 @@ class BLEUBERIEnv(BaseEnv):
|
|||
# Get references from item metadata
|
||||
references = item.metadata.get("references", [])
|
||||
|
||||
# Calculate scores using the specified reward functions
|
||||
scores = []
|
||||
for reward_func_name in self.config.reward_funcs:
|
||||
if reward_func_name in REWARD_FUNCTIONS:
|
||||
reward_func = REWARD_FUNCTIONS[reward_func_name]
|
||||
score = reward_func([response_content], [references])[0]
|
||||
scores.append(score)
|
||||
else:
|
||||
logger.warning(f"Unknown reward function: {reward_func_name}")
|
||||
|
||||
# Take the average of all scores
|
||||
final_score = sum(scores) / len(scores) if scores else 0.0
|
||||
# Calculate score using the specified reward functions
|
||||
final_score = await self._calculate_reward(response_content, references)
|
||||
|
||||
# Track whether the response was deemed correct (score > 0.5)
|
||||
self.percent_correct_buffer.append(1.0 if final_score > 0.5 else 0.0)
|
||||
|
|
@ -255,15 +442,15 @@ class BLEUBERIEnv(BaseEnv):
|
|||
return scored_data, backlog
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in collect_trajectory: {e}")
|
||||
self.logger.error(f"Error in collect_trajectory: {e}")
|
||||
return None, backlog
|
||||
|
||||
async def evaluate(self):
|
||||
"""Evaluate the model on the test set."""
|
||||
logger.info("Starting evaluation")
|
||||
self.logger.info("Starting evaluation")
|
||||
|
||||
if not self.test_examples:
|
||||
logger.warning("No test examples available for evaluation")
|
||||
self.logger.warning("No test examples available for evaluation")
|
||||
return
|
||||
|
||||
# Track evaluation metrics
|
||||
|
|
@ -278,7 +465,7 @@ class BLEUBERIEnv(BaseEnv):
|
|||
for example in self.test_examples:
|
||||
# Create messages
|
||||
messages = []
|
||||
if self.config.system_prompt:
|
||||
if self.config.system_prompt and self.config.reasoning:
|
||||
messages.append(
|
||||
{"role": "system", "content": self.config.system_prompt}
|
||||
)
|
||||
|
|
@ -308,16 +495,8 @@ class BLEUBERIEnv(BaseEnv):
|
|||
# Get references
|
||||
references = example.get("references", [])
|
||||
|
||||
# Calculate scores
|
||||
scores = []
|
||||
for reward_func_name in self.config.reward_funcs:
|
||||
if reward_func_name in REWARD_FUNCTIONS:
|
||||
reward_func = REWARD_FUNCTIONS[reward_func_name]
|
||||
score = reward_func([response_content], [references])[0]
|
||||
scores.append(score)
|
||||
|
||||
# Take the average of all scores
|
||||
final_score = sum(scores) / len(scores) if scores else 0.0
|
||||
# Calculate score
|
||||
final_score = await self._calculate_reward(response_content, references)
|
||||
all_scores.append(final_score)
|
||||
|
||||
# Count as correct if score > 0.5
|
||||
|
|
@ -336,7 +515,7 @@ class BLEUBERIEnv(BaseEnv):
|
|||
total_count += 1
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in evaluation: {e}")
|
||||
self.logger.error(f"Error in evaluation: {e}")
|
||||
|
||||
# Calculate evaluation metrics
|
||||
accuracy = correct_count / total_count if total_count > 0 else 0
|
||||
|
|
@ -351,7 +530,7 @@ class BLEUBERIEnv(BaseEnv):
|
|||
}
|
||||
|
||||
await self.wandb_log(eval_metrics)
|
||||
logger.info(f"Evaluation completed: Accuracy = {accuracy:.4f}")
|
||||
self.logger.info(f"Evaluation completed: Accuracy = {accuracy:.4f}")
|
||||
|
||||
async def wandb_log(self, wandb_metrics: Optional[Dict] = None):
|
||||
"""Log metrics to wandb."""
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue