Integrate BLEUBERI as a submodule with direct import of reference-based reward functions.

This commit is contained in:
Allan Niemerg 2025-06-08 19:21:14 -05:00
parent 5bb5bd2c3d
commit a520f5f663
4 changed files with 270 additions and 70 deletions

View file

@ -7,8 +7,9 @@ using BLEU scores as rewards. Based on the paper:
https://arxiv.org/abs/2505.11080
"""
import logging
import os
import random
import sys
from typing import Dict, List, Optional, Tuple
import numpy as np
@ -17,14 +18,15 @@ from pydantic import Field
import wandb
from atroposlib.envs.base import BaseEnv, BaseEnvConfig, Item
from .dataset_utils import (
aggregate_references,
load_tulu_dataset,
select_examples,
# Add the BLEUBERI repository to the Python path
_SUBMODULE_DIR = os.path.abspath(
os.path.join(os.path.dirname(__file__), "bleuberi-repo")
)
from .reward_functions import REWARD_FUNCTIONS
if _SUBMODULE_DIR not in sys.path:
sys.path.insert(0, _SUBMODULE_DIR)
logger = logging.getLogger(__name__)
# Import components directly from the BLEUBERI repository
from training.dataset import KeywordDataset # noqa: E402
class BLEUBERIEnvConfig(BaseEnvConfig):
@ -39,6 +41,18 @@ class BLEUBERIEnvConfig(BaseEnvConfig):
default="train",
description="Dataset split to use",
)
cache_dir: Optional[str] = Field(
default=None,
description="Cache directory for datasets and models",
)
streaming: bool = Field(
default=False,
description="Whether to stream the dataset",
)
shuffle: bool = Field(
default=True,
description="Whether to shuffle the dataset",
)
# Reference model configuration
ref_models: List[str] = Field(
@ -68,11 +82,18 @@ class BLEUBERIEnvConfig(BaseEnvConfig):
"You are a deep thinking AI, you may use extremely long chains of thought to deeply consider "
"the problem and deliberate with yourself via systematic reasoning processes to help come to "
"a correct solution prior to answering. You should enclose your thoughts and internal monologue "
"inside <think> </think> tags, and then provide your solution or response to the problem."
"inside <think> </think> tags, and then provide your solution or response to the problem. "
"After your thinking, make sure to clearly provide your final answer inside <answer></answer> tags."
),
description="System prompt for the model",
)
# Reasoning
reasoning: bool = Field(
default=True,
description="Whether to enable reasoning in the system prompt",
)
# Random seed
seed: int = Field(
default=42,
@ -115,17 +136,65 @@ class BLEUBERIEnv(BaseEnv):
random.seed(config.seed)
np.random.seed(config.seed)
# Initialize reward functions and metrics
self._init_metrics()
def _init_metrics(self):
"""Initialize BLEUBERI dataset for reward calculation."""
# Import logging here to avoid the unused import warning
import logging
self.logger = logging.getLogger(self.__class__.__name__)
# We'll initialize a KeywordDataset instance that will be used for reward calculation
# The path parameter is not important as we're only using the reward functions
self.bleuberi_dataset = KeywordDataset("", self.tokenizer)
# Initialize the metrics from BLEUBERI
self.bleu = self.bleuberi_dataset.bleu
self.rouge = self.bleuberi_dataset.rouge
self.bertscore = self.bleuberi_dataset.bertscore
self.logger.info("BLEUBERI reward metrics initialized")
async def setup(self):
"""Set up the environment, loading datasets and preparing examples."""
logger.info("Setting up BLEUBERI environment")
self.logger.info("Setting up BLEUBERI environment")
# Load dataset
self.dataset = load_tulu_dataset(
dataset_name=self.config.dataset_name,
dataset_split=self.config.dataset_split,
shuffle=True,
seed=self.config.seed,
)
try:
from datasets import load_dataset
self.dataset = load_dataset(
self.config.dataset_name,
split=self.config.dataset_split,
cache_dir=self.config.cache_dir,
streaming=self.config.streaming,
)
if self.config.shuffle and not self.config.streaming:
self.dataset = self.dataset.shuffle(seed=self.config.seed)
self.logger.info(f"Loaded dataset with {len(self.dataset)} examples")
except Exception as e:
self.logger.error(f"Error loading dataset: {e}")
# Create a small dummy dataset for testing
from datasets import Dataset
dummy_data = []
for i in range(10):
dummy_data.append(
{
"id": i,
"messages": [
{"role": "user", "content": f"Sample prompt {i}"},
{"role": "assistant", "content": f"Sample response {i}"},
],
"source": "dummy",
}
)
self.dataset = Dataset.from_list(dummy_data)
self.logger.info(f"Created dummy dataset with {len(self.dataset)} examples")
# Split into train and test (98% train, 2% test)
train_size = int(0.98 * len(self.dataset))
@ -133,35 +202,16 @@ class BLEUBERIEnv(BaseEnv):
train_dataset = self.dataset.select(range(train_size))
test_dataset = self.dataset.select(range(train_size, len(self.dataset)))
logger.info(
self.logger.info(
f"Split dataset into {len(train_dataset)} train and {len(test_dataset)} test examples"
)
# Aggregate references
self.train_aggregated = aggregate_references(
train_dataset, self.config.ref_models
)
self.test_aggregated = aggregate_references(
test_dataset, self.config.ref_models
)
self.train_examples = await self._aggregate_references(train_dataset)
self.test_examples = await self._aggregate_references(test_dataset)
# Select examples based on selection mode
self.train_examples = select_examples(
self.train_aggregated,
selection_mode=self.config.selection_mode,
num_examples=self.config.num_examples,
seed=self.config.seed,
)
self.test_examples = select_examples(
self.test_aggregated,
selection_mode="random",
num_examples=min(len(self.test_aggregated), 100), # Limit test set size
seed=self.config.seed,
)
logger.info(
f"Selected {len(self.train_examples)} train and {len(self.test_examples)} test examples"
self.logger.info(
f"Prepared {len(self.train_examples)} train and {len(self.test_examples)} test examples"
)
# Shuffle train examples
@ -170,10 +220,56 @@ class BLEUBERIEnv(BaseEnv):
self.train_index = 0
async def _aggregate_references(self, dataset):
"""
Aggregate references from the dataset based on specified reference models.
This is an async wrapper around the BLEUBERI aggregation logic.
"""
# Process examples to extract prompts and references
examples = []
for example in dataset:
example_id = example.get("id", "unknown_id")
# Get prompt from messages
prompt = None
if "messages" in example and example["messages"]:
for msg in example["messages"]:
if msg.get("role") == "user":
prompt = msg.get("content")
break
# Get ground truth from messages
ground_truth = None
if "messages" in example and example["messages"]:
for msg in example["messages"]:
if msg.get("role") == "assistant":
ground_truth = msg.get("content")
break
# Skip examples without prompt or ground truth
if not prompt or not ground_truth:
continue
# Create example
aggregated_example = {
"id": example_id,
"source": example.get("source", "unknown"),
"messages": example.get("messages", []),
"prompt": prompt,
"ground_truth": ground_truth,
"references": [ground_truth], # Using ground truth as reference
}
examples.append(aggregated_example)
return examples
async def get_next_item(self) -> Item:
"""Get the next example from the dataset."""
if not self.train_examples:
logger.warning("No train examples available")
self.logger.warning("No train examples available")
return None
# Cycle through the dataset
@ -183,8 +279,8 @@ class BLEUBERIEnv(BaseEnv):
# Format the prompt as a conversation
messages = []
# Add system message if provided
if self.config.system_prompt:
# Add system message if provided and reasoning is enabled
if self.config.system_prompt and self.config.reasoning:
messages.append({"role": "system", "content": self.config.system_prompt})
# Add user prompt
@ -207,6 +303,107 @@ class BLEUBERIEnv(BaseEnv):
return item
def _extract_answer(self, completion: str) -> str:
"""Extract the answer from a completion with potential thinking tags."""
# Use the extract_answer method from BLEUBERI's KeywordDataset
return KeywordDataset.extract_answer(self, completion)
async def _calculate_bleu_score(
self, response_content: str, references: List[str]
) -> float:
"""Calculate BLEU score for a response against references using BLEUBERI implementation."""
# Create a mock dataset instance to access the reward functions
dataset = KeywordDataset("", self.tokenizer)
# Prepare the inputs in the format expected by BLEUBERI
completion = response_content
kwargs = {"references": references}
# Use BLEUBERI's bleu_reward_func method
scores = dataset.bleu_reward_func([completion], **kwargs)
return scores[0] if scores else 0.0
async def _calculate_rouge_score(
self, response_content: str, references: List[str]
) -> float:
"""Calculate ROUGE score for a response against references using BLEUBERI implementation."""
# Create a mock dataset instance to access the reward functions
dataset = KeywordDataset("", self.tokenizer)
# Prepare the inputs in the format expected by BLEUBERI
completion = response_content
kwargs = {"references": references}
# Use BLEUBERI's rouge_reward_func method
scores = dataset.rouge_reward_func([completion], **kwargs)
return scores[0] if scores else 0.0
async def _calculate_bertscore(
self, response_content: str, references: List[str]
) -> float:
"""Calculate BERTScore for a response against references using BLEUBERI implementation."""
# Create a mock dataset instance to access the reward functions
dataset = KeywordDataset("", self.tokenizer)
# Prepare the inputs in the format expected by BLEUBERI
completion = response_content
kwargs = {"references": references}
# Use BLEUBERI's bertscore_reward_func method
scores = dataset.bertscore_reward_func([completion], **kwargs)
return scores[0] if scores else 0.0
async def _calculate_bleu_rouge_f1(
self, response_content: str, references: List[str]
) -> float:
"""Calculate F1 of BLEU and ROUGE scores using BLEUBERI implementation."""
# Create a mock dataset instance to access the reward functions
dataset = KeywordDataset("", self.tokenizer)
# Prepare the inputs in the format expected by BLEUBERI
completion = response_content
kwargs = {"references": references}
# Use BLEUBERI's bleu_rouge_f1_reward_func method
scores = dataset.bleu_rouge_f1_reward_func([completion], **kwargs)
return scores[0] if scores else 0.0
async def _calculate_reward(
self, response_content: str, references: List[str]
) -> float:
"""
Calculate the reward for a response based on the configured reward functions.
Uses BLEUBERI's reward functions directly.
"""
# Get the appropriate reward functions from BLEUBERI's implementation
reward_funcs = self.bleuberi_dataset.get_reward_funcs(self.config.reward_funcs)
if not reward_funcs:
self.logger.warning("No valid reward functions found")
return 0.0
# Calculate scores using each reward function
all_scores = []
for reward_func in reward_funcs:
# Apply the reward function
kwargs = {"references": references}
if (
hasattr(reward_func, "__name__")
and reward_func.__name__ == "rm_reward_func"
):
# RM reward function requires prompts
kwargs["prompts"] = [
[{"role": "user", "content": "prompt"}]
] # dummy prompt
scores = reward_func([response_content], **kwargs)
if scores and len(scores) > 0:
all_scores.append(scores[0])
# Take the average of all scores
final_score = sum(all_scores) / len(all_scores) if all_scores else 0.0
return final_score
async def collect_trajectory(self, item: Item) -> Tuple[Dict, List[Item]]:
"""Generate a response and score it against references."""
backlog = []
@ -221,18 +418,8 @@ class BLEUBERIEnv(BaseEnv):
# Get references from item metadata
references = item.metadata.get("references", [])
# Calculate scores using the specified reward functions
scores = []
for reward_func_name in self.config.reward_funcs:
if reward_func_name in REWARD_FUNCTIONS:
reward_func = REWARD_FUNCTIONS[reward_func_name]
score = reward_func([response_content], [references])[0]
scores.append(score)
else:
logger.warning(f"Unknown reward function: {reward_func_name}")
# Take the average of all scores
final_score = sum(scores) / len(scores) if scores else 0.0
# Calculate score using the specified reward functions
final_score = await self._calculate_reward(response_content, references)
# Track whether the response was deemed correct (score > 0.5)
self.percent_correct_buffer.append(1.0 if final_score > 0.5 else 0.0)
@ -255,15 +442,15 @@ class BLEUBERIEnv(BaseEnv):
return scored_data, backlog
except Exception as e:
logger.error(f"Error in collect_trajectory: {e}")
self.logger.error(f"Error in collect_trajectory: {e}")
return None, backlog
async def evaluate(self):
"""Evaluate the model on the test set."""
logger.info("Starting evaluation")
self.logger.info("Starting evaluation")
if not self.test_examples:
logger.warning("No test examples available for evaluation")
self.logger.warning("No test examples available for evaluation")
return
# Track evaluation metrics
@ -278,7 +465,7 @@ class BLEUBERIEnv(BaseEnv):
for example in self.test_examples:
# Create messages
messages = []
if self.config.system_prompt:
if self.config.system_prompt and self.config.reasoning:
messages.append(
{"role": "system", "content": self.config.system_prompt}
)
@ -308,16 +495,8 @@ class BLEUBERIEnv(BaseEnv):
# Get references
references = example.get("references", [])
# Calculate scores
scores = []
for reward_func_name in self.config.reward_funcs:
if reward_func_name in REWARD_FUNCTIONS:
reward_func = REWARD_FUNCTIONS[reward_func_name]
score = reward_func([response_content], [references])[0]
scores.append(score)
# Take the average of all scores
final_score = sum(scores) / len(scores) if scores else 0.0
# Calculate score
final_score = await self._calculate_reward(response_content, references)
all_scores.append(final_score)
# Count as correct if score > 0.5
@ -336,7 +515,7 @@ class BLEUBERIEnv(BaseEnv):
total_count += 1
except Exception as e:
logger.error(f"Error in evaluation: {e}")
self.logger.error(f"Error in evaluation: {e}")
# Calculate evaluation metrics
accuracy = correct_count / total_count if total_count > 0 else 0
@ -351,7 +530,7 @@ class BLEUBERIEnv(BaseEnv):
}
await self.wandb_log(eval_metrics)
logger.info(f"Evaluation completed: Accuracy = {accuracy:.4f}")
self.logger.info(f"Evaluation completed: Accuracy = {accuracy:.4f}")
async def wandb_log(self, wandb_metrics: Optional[Dict] = None):
"""Log metrics to wandb."""