mirror of
https://github.com/NousResearch/atropos.git
synced 2026-05-02 17:45:50 +00:00
Fix BLEUBERI environment server integration
This commit is contained in:
parent
3109fe349b
commit
64a82c4b4f
3 changed files with 59 additions and 23 deletions
|
|
@ -10,13 +10,24 @@ https://arxiv.org/abs/2505.11080
|
|||
import os
|
||||
import random
|
||||
import sys
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
from pydantic import Field
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
import wandb
|
||||
from atroposlib.envs.base import BaseEnv, BaseEnvConfig, Item
|
||||
from atroposlib.envs.base import BaseEnv, BaseEnvConfig, ScoredDataItem
|
||||
|
||||
|
||||
# Define our own Item class for the environment
|
||||
class BLEUBERIItem(TypedDict):
|
||||
"""Item for BLEUBERI environment"""
|
||||
|
||||
id: str
|
||||
messages: List[Dict[str, Any]]
|
||||
metadata: Dict[str, Any]
|
||||
|
||||
|
||||
# Add the BLEUBERI repository to the Python path
|
||||
_SUBMODULE_DIR = os.path.abspath(
|
||||
|
|
@ -266,7 +277,7 @@ class BLEUBERIEnv(BaseEnv):
|
|||
|
||||
return examples
|
||||
|
||||
async def get_next_item(self) -> Item:
|
||||
async def get_next_item(self) -> BLEUBERIItem:
|
||||
"""Get the next example from the dataset."""
|
||||
if not self.train_examples:
|
||||
self.logger.warning("No train examples available")
|
||||
|
|
@ -291,15 +302,15 @@ class BLEUBERIEnv(BaseEnv):
|
|||
messages.append({"role": "user", "content": user_prompt})
|
||||
|
||||
# Create item
|
||||
item = Item(
|
||||
messages=messages,
|
||||
id=str(example.get("id", f"item_{self.train_index}")),
|
||||
metadata={
|
||||
item: BLEUBERIItem = {
|
||||
"messages": messages,
|
||||
"id": str(example.get("id", f"item_{self.train_index}")),
|
||||
"metadata": {
|
||||
"references": example.get("references", []),
|
||||
"source": example.get("source", "unknown"),
|
||||
"prompt": user_prompt,
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
return item
|
||||
|
||||
|
|
@ -404,19 +415,21 @@ class BLEUBERIEnv(BaseEnv):
|
|||
final_score = sum(all_scores) / len(all_scores) if all_scores else 0.0
|
||||
return final_score
|
||||
|
||||
async def collect_trajectory(self, item: Item) -> Tuple[Dict, List[Item]]:
|
||||
async def collect_trajectory(
|
||||
self, item: BLEUBERIItem
|
||||
) -> Tuple[Optional[ScoredDataItem], List[BLEUBERIItem]]:
|
||||
"""Generate a response and score it against references."""
|
||||
backlog = []
|
||||
|
||||
try:
|
||||
# Generate response using the server
|
||||
response = await self.server.generate_chat_completion(item.messages)
|
||||
response = await self.server.chat_completion(messages=item["messages"])
|
||||
|
||||
# Extract response content
|
||||
response_content = response.get("content", "")
|
||||
response_content = response.choices[0].message.content
|
||||
|
||||
# Get references from item metadata
|
||||
references = item.metadata.get("references", [])
|
||||
references = item["metadata"].get("references", [])
|
||||
|
||||
# Calculate score using the specified reward functions
|
||||
final_score = await self._calculate_reward(response_content, references)
|
||||
|
|
@ -430,13 +443,18 @@ class BLEUBERIEnv(BaseEnv):
|
|||
tokens = self.tokenizer.encode(response_content)
|
||||
mask = [1] * len(tokens)
|
||||
|
||||
# Create scored data item
|
||||
scored_data = {
|
||||
# Create scored data item as ScoredDataItem
|
||||
scored_data: ScoredDataItem = {
|
||||
"tokens": tokens,
|
||||
"masks": mask,
|
||||
"scores": final_score,
|
||||
"messages": item.messages
|
||||
"messages": item["messages"]
|
||||
+ [{"role": "assistant", "content": response_content}],
|
||||
"advantages": None,
|
||||
"ref_logprobs": None,
|
||||
"group_overrides": None,
|
||||
"overrides": None,
|
||||
"images": None,
|
||||
}
|
||||
|
||||
return scored_data, backlog
|
||||
|
|
@ -477,20 +495,20 @@ class BLEUBERIEnv(BaseEnv):
|
|||
messages.append({"role": "user", "content": user_prompt})
|
||||
|
||||
# Create item
|
||||
item = Item(
|
||||
messages=messages,
|
||||
id=str(example.get("id", f"eval_{total_count}")),
|
||||
metadata={
|
||||
item: BLEUBERIItem = {
|
||||
"messages": messages,
|
||||
"id": str(example.get("id", f"eval_{total_count}")),
|
||||
"metadata": {
|
||||
"references": example.get("references", []),
|
||||
"source": example.get("source", "unknown"),
|
||||
"prompt": user_prompt,
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
# Generate response
|
||||
try:
|
||||
response = await self.server.generate_chat_completion(item.messages)
|
||||
response_content = response.get("content", "")
|
||||
response = await self.server.chat_completion(messages=item["messages"])
|
||||
response_content = response.choices[0].message.content
|
||||
|
||||
# Get references
|
||||
references = example.get("references", [])
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue