diff --git a/environments/bleuberi/README.md b/environments/bleuberi/README.md
index ef1841ff..6c18b88b 100644
--- a/environments/bleuberi/README.md
+++ b/environments/bleuberi/README.md
@@ -10,6 +10,22 @@ BLEUBERI uses BLEU scores (a simple n-gram matching metric) directly as rewards
 2. Computes BLEU scores by comparing model outputs to these references
 3. Uses these scores as rewards to train models through GRPO
 
+## Installation
+
+Before using the BLEUBERI environment, you need to install its dependencies:
+
+```bash
+# Install the required dependencies
+pip install -r environments/bleuberi/requirements.txt
+```
+
+The key dependencies include:
+- `model2vec`: For embedding-based similarity metrics
+- `bert-score`: For semantic similarity evaluation
+- `sacrebleu`: For BLEU score calculation
+- `evaluate`: For evaluation metrics
+- `datasets`: For dataset handling
+
 ## Features
 
 - BLEU-based reward functions (with support for multiple reference models)
diff --git a/environments/bleuberi/bleuberi_env.py b/environments/bleuberi/bleuberi_env.py
index c9b74a66..60bf93cb 100644
--- a/environments/bleuberi/bleuberi_env.py
+++ b/environments/bleuberi/bleuberi_env.py
@@ -10,13 +10,24 @@ https://arxiv.org/abs/2505.11080
 import os
 import random
 import sys
-from typing import Dict, List, Optional, Tuple
+from typing import Any, Dict, List, Optional, Tuple
 
 import numpy as np
 from pydantic import Field
+from typing_extensions import TypedDict
 
 import wandb
-from atroposlib.envs.base import BaseEnv, BaseEnvConfig, Item
+from atroposlib.envs.base import BaseEnv, BaseEnvConfig, ScoredDataItem
+
+
+# Define our own Item class for the environment
+class BLEUBERIItem(TypedDict):
+    """Item for BLEUBERI environment"""
+
+    id: str
+    messages: List[Dict[str, Any]]
+    metadata: Dict[str, Any]
+
 
 # Add the BLEUBERI repository to the Python path
 _SUBMODULE_DIR = os.path.abspath(
@@ -266,7 +277,7 @@ class BLEUBERIEnv(BaseEnv):
 
         return examples
 
-    async def get_next_item(self) -> Item:
+    async def get_next_item(self) -> BLEUBERIItem:
         """Get the next example from the dataset."""
         if not self.train_examples:
             self.logger.warning("No train examples available")
@@ -291,15 +302,15 @@ class BLEUBERIEnv(BaseEnv):
         messages.append({"role": "user", "content": user_prompt})
 
         # Create item
-        item = Item(
-            messages=messages,
-            id=str(example.get("id", f"item_{self.train_index}")),
-            metadata={
+        item: BLEUBERIItem = {
+            "messages": messages,
+            "id": str(example.get("id", f"item_{self.train_index}")),
+            "metadata": {
                 "references": example.get("references", []),
                 "source": example.get("source", "unknown"),
                 "prompt": user_prompt,
             },
-        )
+        }
 
         return item
 
@@ -404,19 +415,21 @@ class BLEUBERIEnv(BaseEnv):
         final_score = sum(all_scores) / len(all_scores) if all_scores else 0.0
         return final_score
 
-    async def collect_trajectory(self, item: Item) -> Tuple[Dict, List[Item]]:
+    async def collect_trajectory(
+        self, item: BLEUBERIItem
+    ) -> Tuple[Optional[ScoredDataItem], List[BLEUBERIItem]]:
         """Generate a response and score it against references."""
         backlog = []
 
         try:
             # Generate response using the server
-            response = await self.server.generate_chat_completion(item.messages)
+            response = await self.server.chat_completion(messages=item["messages"])
 
             # Extract response content
-            response_content = response.get("content", "")
+            response_content = response.choices[0].message.content
 
             # Get references from item metadata
-            references = item.metadata.get("references", [])
+            references = item["metadata"].get("references", [])
 
             # Calculate score using the specified reward functions
             final_score = await self._calculate_reward(response_content, references)
@@ -430,13 +443,18 @@ class BLEUBERIEnv(BaseEnv):
             tokens = self.tokenizer.encode(response_content)
             mask = [1] * len(tokens)
 
-            # Create scored data item
-            scored_data = {
+            # Create scored data item as ScoredDataItem
+            scored_data: ScoredDataItem = {
                 "tokens": tokens,
                 "masks": mask,
                 "scores": final_score,
-                "messages": item.messages
+                "messages": item["messages"]
                 + [{"role": "assistant", "content": response_content}],
+                "advantages": None,
+                "ref_logprobs": None,
+                "group_overrides": None,
+                "overrides": None,
+                "images": None,
             }
 
             return scored_data, backlog
@@ -477,20 +495,20 @@ class BLEUBERIEnv(BaseEnv):
             messages.append({"role": "user", "content": user_prompt})
 
             # Create item
-            item = Item(
-                messages=messages,
-                id=str(example.get("id", f"eval_{total_count}")),
-                metadata={
+            item: BLEUBERIItem = {
+                "messages": messages,
+                "id": str(example.get("id", f"eval_{total_count}")),
+                "metadata": {
                     "references": example.get("references", []),
                     "source": example.get("source", "unknown"),
                     "prompt": user_prompt,
                 },
-            )
+            }
 
             # Generate response
             try:
-                response = await self.server.generate_chat_completion(item.messages)
-                response_content = response.get("content", "")
+                response = await self.server.chat_completion(messages=item["messages"])
+                response_content = response.choices[0].message.content
 
                 # Get references
                 references = example.get("references", [])
diff --git a/environments/bleuberi/requirements.txt b/environments/bleuberi/requirements.txt
index 7d476b60..a35ae496 100644
--- a/environments/bleuberi/requirements.txt
+++ b/environments/bleuberi/requirements.txt
@@ -1,10 +1,12 @@
 sacrebleu>=2.5.1
 bert-score>=0.3.13
 evaluate>=0.4.3
+rouge_score>=0.1.2
 datasets>=3.3.2
 pandas>=2.2.3
 tqdm>=4.67.1
 numpy>=1.26.4
+model2vec>=0.4.1
+absl-py>=0.9.0
 # Optional dependencies for reference generation
 # vllm>=0.8.3
-# model2vec>=0.4.1
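
For anyone sanity-checking the reward signal this patch wires into `collect_trajectory`, the snippet below is a minimal sketch of the BLEU-vs-references scoring described in the README, using only `sacrebleu` from `requirements.txt`. The `bleu_reward` helper is hypothetical: the environment's actual `_calculate_reward` is only called, not shown, in this diff and may blend additional metrics (BERTScore, ROUGE, model2vec similarity).

```python
# Hypothetical, standalone sketch of a BLEU-based reward, assuming only
# `sacrebleu` (already pinned in requirements.txt). This is not the
# environment's `_calculate_reward`, which is outside this diff.
from typing import List

import sacrebleu


def bleu_reward(response: str, references: List[str]) -> float:
    """Score a model response against reference answers, scaled to [0, 1]."""
    if not response or not references:
        return 0.0
    # sentence_bleu returns a BLEUScore object whose .score is on a 0-100 scale.
    result = sacrebleu.sentence_bleu(response, references)
    return result.score / 100.0


if __name__ == "__main__":
    # An exact n-gram match against a reference yields a reward of 1.0.
    print(bleu_reward("The cat sat on the mat.", ["The cat sat on the mat."]))
```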