Fix BLEUBERI environment server integration

Allan Niemerg 2025-06-09 10:28:00 -05:00
parent 3109fe349b
commit 64a82c4b4f
3 changed files with 59 additions and 23 deletions


@@ -10,13 +10,24 @@ https://arxiv.org/abs/2505.11080
 import os
 import random
 import sys
-from typing import Dict, List, Optional, Tuple
+from typing import Any, Dict, List, Optional, Tuple

 import numpy as np
 from pydantic import Field
+from typing_extensions import TypedDict

 import wandb

-from atroposlib.envs.base import BaseEnv, BaseEnvConfig, Item
+from atroposlib.envs.base import BaseEnv, BaseEnvConfig, ScoredDataItem

+
+# Define our own Item class for the environment
+class BLEUBERIItem(TypedDict):
+    """Item for BLEUBERI environment"""
+
+    id: str
+    messages: List[Dict[str, Any]]
+    metadata: Dict[str, Any]
+
 # Add the BLEUBERI repository to the Python path
 _SUBMODULE_DIR = os.path.abspath(
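
For orientation, a minimal standalone sketch of how the new BLEUBERIItem behaves; the field names come from the diff above, while the sample values are invented for illustration:

from typing import Any, Dict, List

from typing_extensions import TypedDict


class BLEUBERIItem(TypedDict):
    """Item for BLEUBERI environment"""

    id: str
    messages: List[Dict[str, Any]]
    metadata: Dict[str, Any]


# A TypedDict instance is a plain dict, so fields are read with
# subscript access (item["messages"]) rather than the attribute
# access (item.messages) the old Item class supported.
item: BLEUBERIItem = {
    "id": "item_0",  # invented sample id
    "messages": [{"role": "user", "content": "Summarize the paper."}],
    "metadata": {"references": ["a reference summary"], "source": "demo"},
}
print(item["messages"][-1]["content"])

This dict-style access is what drives the item.messages to item["messages"] changes in the hunks below.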
@@ -266,7 +277,7 @@ class BLEUBERIEnv(BaseEnv):
         return examples

-    async def get_next_item(self) -> Item:
+    async def get_next_item(self) -> BLEUBERIItem:
         """Get the next example from the dataset."""
         if not self.train_examples:
             self.logger.warning("No train examples available")
@@ -291,15 +302,15 @@ class BLEUBERIEnv(BaseEnv):
         messages.append({"role": "user", "content": user_prompt})

         # Create item
-        item = Item(
-            messages=messages,
-            id=str(example.get("id", f"item_{self.train_index}")),
-            metadata={
+        item: BLEUBERIItem = {
+            "messages": messages,
+            "id": str(example.get("id", f"item_{self.train_index}")),
+            "metadata": {
                 "references": example.get("references", []),
                 "source": example.get("source", "unknown"),
                 "prompt": user_prompt,
             },
-        )
+        }

         return item
@@ -404,19 +415,21 @@ class BLEUBERIEnv(BaseEnv):
         final_score = sum(all_scores) / len(all_scores) if all_scores else 0.0
         return final_score

-    async def collect_trajectory(self, item: Item) -> Tuple[Dict, List[Item]]:
+    async def collect_trajectory(
+        self, item: BLEUBERIItem
+    ) -> Tuple[Optional[ScoredDataItem], List[BLEUBERIItem]]:
         """Generate a response and score it against references."""
         backlog = []

         try:
             # Generate response using the server
-            response = await self.server.generate_chat_completion(item.messages)
+            response = await self.server.chat_completion(messages=item["messages"])

             # Extract response content
-            response_content = response.get("content", "")
+            response_content = response.choices[0].message.content

             # Get references from item metadata
-            references = item.metadata.get("references", [])
+            references = item["metadata"].get("references", [])

             # Calculate score using the specified reward functions
             final_score = await self._calculate_reward(response_content, references)
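
A condensed sketch of the corrected request/response flow, assuming (as the new code implies) that self.server.chat_completion returns an OpenAI-style completion object; generate_once is a hypothetical helper, not part of the diff:

async def generate_once(self, item: BLEUBERIItem) -> str:
    # New API: a `messages` keyword argument replaces the removed
    # generate_chat_completion call.
    response = await self.server.chat_completion(messages=item["messages"])
    # OpenAI-style shape: the reply text lives on choices[0].message,
    # not in a top-level "content" key as the old code assumed.
    return response.choices[0].message.content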
@@ -430,13 +443,18 @@ class BLEUBERIEnv(BaseEnv):
             tokens = self.tokenizer.encode(response_content)
             mask = [1] * len(tokens)

-            # Create scored data item
-            scored_data = {
+            # Create scored data item as ScoredDataItem
+            scored_data: ScoredDataItem = {
                 "tokens": tokens,
                 "masks": mask,
                 "scores": final_score,
-                "messages": item.messages
+                "messages": item["messages"]
                 + [{"role": "assistant", "content": response_content}],
+                "advantages": None,
+                "ref_logprobs": None,
+                "group_overrides": None,
+                "overrides": None,
+                "images": None,
             }

             return scored_data, backlog
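
For clarity, the payload assembled above can be read as one function; make_scored_data is a hypothetical helper whose key set simply mirrors the ScoredDataItem fields the diff fills in:

def make_scored_data(tokens, mask, score, messages) -> ScoredDataItem:
    return {
        "tokens": tokens,            # token ids of the assistant reply
        "masks": mask,               # one 1 per token: all tokens trainable
        "scores": score,             # scalar reward from _calculate_reward
        "messages": messages,        # prompt plus the assistant turn
        "advantages": None,          # optional fields left unset here;
        "ref_logprobs": None,        # the trainer may populate them later
        "group_overrides": None,
        "overrides": None,
        "images": None,
    }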
@@ -477,20 +495,20 @@ class BLEUBERIEnv(BaseEnv):
             messages.append({"role": "user", "content": user_prompt})

             # Create item
-            item = Item(
-                messages=messages,
-                id=str(example.get("id", f"eval_{total_count}")),
-                metadata={
+            item: BLEUBERIItem = {
+                "messages": messages,
+                "id": str(example.get("id", f"eval_{total_count}")),
+                "metadata": {
                     "references": example.get("references", []),
                     "source": example.get("source", "unknown"),
                     "prompt": user_prompt,
                 },
-            )
+            }

             # Generate response
             try:
-                response = await self.server.generate_chat_completion(item.messages)
-                response_content = response.get("content", "")
+                response = await self.server.chat_completion(messages=item["messages"])
+                response_content = response.choices[0].message.content

                 # Get references
                 references = example.get("references", [])