mirror of
https://github.com/NousResearch/atropos.git
synced 2026-05-02 17:45:50 +00:00
remove unnecessary code, change log level
This commit is contained in:
parent
1a2551c812
commit
532024d01e
1 changed file with 48 additions and 195 deletions
|
|
@@ -20,12 +20,12 @@ from pydantic import Field
|
|||
from typing_extensions import TypedDict
|
||||
|
||||
import wandb
|
||||
from atroposlib.envs.base import BaseEnv, BaseEnvConfig, ScoredDataItem, ScoredDataGroup
|
||||
from atroposlib.envs.base import BaseEnv, BaseEnvConfig, ScoredDataGroup, ScoredDataItem
|
||||
from atroposlib.envs.server_handling.openai_server import APIServerConfig
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
level=logging.WARNING, # Changed from INFO to WARNING to reduce verbosity
|
||||
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
|
||||
handlers=[logging.StreamHandler(sys.stdout)],
|
||||
)
|
||||
|
|
@@ -33,16 +33,6 @@ logging.basicConfig(
|
|||
# Load environment variables from .env file if available
|
||||
load_dotenv()
|
||||
|
||||
# Check for OpenAI API key
|
||||
if not os.environ.get("OPENAI_API_KEY"):
|
||||
print(
|
||||
"WARNING: OPENAI_API_KEY environment variable not found. Make sure to set it!"
|
||||
)
|
||||
else:
|
||||
print(
|
||||
f"Found OPENAI_API_KEY environment variable ({os.environ.get('OPENAI_API_KEY')[:5]}...)"
|
||||
)
|
||||
|
||||
|
||||
# Define our own Item class for the environment
|
||||
class BLEUBERIItem(TypedDict):
|
||||
|
|
@@ -163,46 +153,38 @@ class BLEUBERIEnv(BaseEnv):
|
|||
"""Initialize configuration with OpenAI API settings."""
|
||||
# Load API key from environment
|
||||
api_key = os.environ.get("OPENAI_API_KEY")
|
||||
if not api_key:
|
||||
print("WARNING: OPENAI_API_KEY environment variable not found!")
|
||||
print(
|
||||
"Please set the OPENAI_API_KEY environment variable or add it to a .env file"
|
||||
)
|
||||
|
||||
# Create environment config with all necessary settings
|
||||
env_config = BLEUBERIEnvConfig(
|
||||
tokenizer_name="gpt2",
|
||||
group_size=2, # Reduced from 4 to minimize API calls
|
||||
group_size=2,
|
||||
use_wandb=True,
|
||||
rollout_server_url="http://localhost:8000",
|
||||
total_steps=2, # Minimal number of steps for a quick test
|
||||
total_steps=2,
|
||||
batch_size=-1,
|
||||
steps_per_eval=1, # Evaluate after each step for quick testing
|
||||
steps_per_eval=1,
|
||||
max_token_length=2048,
|
||||
wandb_name="bleuberi",
|
||||
dataset_name="allenai/tulu-3-sft-mixture",
|
||||
dataset_split="train",
|
||||
reward_funcs=["bleu"],
|
||||
ref_models=["gold"],
|
||||
# Example limiting for quick testing (remove or set to None for full dataset)
|
||||
max_train_examples=2, # Limit to just 2 training examples for minimal testing
|
||||
max_test_examples=1, # Limit to 1 test example for minimal testing
|
||||
# Parallelism configuration (adjust for your use case)
|
||||
max_num_workers=2, # Limit number of workers for training
|
||||
max_eval_workers=1, # Limit number of workers for evaluation
|
||||
# Optional: Add a place to save the data
|
||||
max_train_examples=2,
|
||||
max_test_examples=1,
|
||||
max_num_workers=2,
|
||||
max_eval_workers=1,
|
||||
data_path_to_save_groups="bleuberi_openai_test.jsonl",
|
||||
)
|
||||
|
||||
# Create OpenAI server config
|
||||
server_configs = [
|
||||
APIServerConfig(
|
||||
model_name="gpt-4.1-nano", # Or your preferred model
|
||||
model_name="gpt-4.1-nano",
|
||||
base_url="https://api.openai.com/v1",
|
||||
api_key=api_key,
|
||||
timeout=60,
|
||||
num_max_requests_at_once=4, # Increased from 1 to allow parallel requests
|
||||
num_requests_for_eval=4, # Increased from 1 to allow parallel evaluation
|
||||
num_max_requests_at_once=4,
|
||||
num_requests_for_eval=4,
|
||||
),
|
||||
]
|
||||
|
||||
|
|
@@ -225,28 +207,13 @@ class BLEUBERIEnv(BaseEnv):
|
|||
api_key = os.environ.get("OPENAI_API_KEY")
|
||||
if not api_key:
|
||||
self.logger.warning("OPENAI_API_KEY environment variable not found!")
|
||||
self.logger.warning(
|
||||
"Please set the OPENAI_API_KEY environment variable or add it to a .env file"
|
||||
)
|
||||
else:
|
||||
self.logger.info(
|
||||
f"Found OPENAI_API_KEY in environment variables: {api_key[:5]}..."
|
||||
)
|
||||
|
||||
# Update server configs with API key if needed
|
||||
for server in server_configs:
|
||||
if getattr(server, "server_type", "") == "openai" and not getattr(
|
||||
server, "api_key", None
|
||||
):
|
||||
server.api_key = api_key
|
||||
self.logger.info(
|
||||
f"Updated server config with API key: {api_key[:5]}..."
|
||||
)
|
||||
|
||||
# Print minimal server configuration info
|
||||
for i, server in enumerate(server_configs):
|
||||
if hasattr(server, "model_name"):
|
||||
self.logger.info(f"Server {i+1} using model: {server.model_name}")
|
||||
|
||||
super().__init__(config, server_configs, slurm, testing)
|
||||
self.config = config
|
||||
|
|
@@ -257,18 +224,10 @@ class BLEUBERIEnv(BaseEnv):
|
|||
self.test_examples = None
|
||||
self.train_index = 0
|
||||
|
||||
# Minimal server initialization message
|
||||
if hasattr(self, "server"):
|
||||
self.logger.info(f"Server initialized with {len(getattr(self.server, 'servers', []))} instances")
|
||||
else:
|
||||
self.logger.warning("No 'server' attribute found after initialization!")
|
||||
|
||||
# Track training metrics
|
||||
self.percent_correct_buffer = []
|
||||
self.token_lengths_buffer = []
|
||||
self.bleu_scores_buffer = []
|
||||
self.rouge_scores_buffer = []
|
||||
self.bertscore_buffer = []
|
||||
self.category_performance = {} # Track performance by category
|
||||
|
||||
# Store rollouts for wandb visualization
|
||||
|
|
@@ -298,39 +257,19 @@ class BLEUBERIEnv(BaseEnv):
|
|||
self.logger.info("Setting up BLEUBERI environment")
|
||||
|
||||
# Load dataset
|
||||
try:
|
||||
from datasets import load_dataset
|
||||
from datasets import load_dataset
|
||||
|
||||
self.dataset = load_dataset(
|
||||
self.config.dataset_name,
|
||||
split=self.config.dataset_split,
|
||||
cache_dir=self.config.cache_dir,
|
||||
streaming=self.config.streaming,
|
||||
)
|
||||
self.dataset = load_dataset(
|
||||
self.config.dataset_name,
|
||||
split=self.config.dataset_split,
|
||||
cache_dir=self.config.cache_dir,
|
||||
streaming=self.config.streaming,
|
||||
)
|
||||
|
||||
if self.config.shuffle and not self.config.streaming:
|
||||
self.dataset = self.dataset.shuffle(seed=self.config.seed)
|
||||
if self.config.shuffle and not self.config.streaming:
|
||||
self.dataset = self.dataset.shuffle(seed=self.config.seed)
|
||||
|
||||
self.logger.info(f"Loaded dataset with {len(self.dataset)} examples")
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error loading dataset: {e}")
|
||||
# Create a small dummy dataset for testing
|
||||
from datasets import Dataset
|
||||
|
||||
dummy_data = []
|
||||
for i in range(10):
|
||||
dummy_data.append(
|
||||
{
|
||||
"id": i,
|
||||
"messages": [
|
||||
{"role": "user", "content": f"Sample prompt {i}"},
|
||||
{"role": "assistant", "content": f"Sample response {i}"},
|
||||
],
|
||||
"source": "dummy",
|
||||
}
|
||||
)
|
||||
self.dataset = Dataset.from_list(dummy_data)
|
||||
self.logger.info(f"Created dummy dataset with {len(self.dataset)} examples")
|
||||
self.logger.info(f"Loaded dataset with {len(self.dataset)} examples")
|
||||
|
||||
# Split into train and test (98% train, 2% test)
|
||||
train_size = int(0.98 * len(self.dataset))
|
||||
|
|
@@ -467,36 +406,8 @@ class BLEUBERIEnv(BaseEnv):
|
|||
self, response_content: str, references: List[str]
|
||||
) -> float:
|
||||
"""Calculate BLEU score for a response against references using BLEUBERI implementation."""
|
||||
dataset = KeywordDataset("", self.tokenizer)
|
||||
kwargs = {"references": references}
|
||||
scores = dataset.bleu_reward_func([response_content], **kwargs)
|
||||
return scores[0] if scores else 0.0
|
||||
|
||||
async def _calculate_rouge_score(
|
||||
self, response_content: str, references: List[str]
|
||||
) -> float:
|
||||
"""Calculate ROUGE score for a response against references using BLEUBERI implementation."""
|
||||
dataset = KeywordDataset("", self.tokenizer)
|
||||
kwargs = {"references": references}
|
||||
scores = dataset.rouge_reward_func([response_content], **kwargs)
|
||||
return scores[0] if scores else 0.0
|
||||
|
||||
async def _calculate_bertscore(
|
||||
self, response_content: str, references: List[str]
|
||||
) -> float:
|
||||
"""Calculate BERTScore for a response against references using BLEUBERI implementation."""
|
||||
dataset = KeywordDataset("", self.tokenizer)
|
||||
kwargs = {"references": references}
|
||||
scores = dataset.bertscore_reward_func([response_content], **kwargs)
|
||||
return scores[0] if scores else 0.0
|
||||
|
||||
async def _calculate_bleu_rouge_f1(
|
||||
self, response_content: str, references: List[str]
|
||||
) -> float:
|
||||
"""Calculate F1 of BLEU and ROUGE scores using BLEUBERI implementation."""
|
||||
dataset = KeywordDataset("", self.tokenizer)
|
||||
kwargs = {"references": references}
|
||||
scores = dataset.bleu_rouge_f1_reward_func([response_content], **kwargs)
|
||||
scores = self.bleuberi_dataset.bleu_reward_func([response_content], **kwargs)
|
||||
return scores[0] if scores else 0.0
|
||||
|
||||
async def _calculate_reward(
|
||||
|
|
@@ -538,20 +449,10 @@ class BLEUBERIEnv(BaseEnv):
|
|||
|
||||
async def cleanup(self):
|
||||
"""
|
||||
Cleanup the environment by cancelling health check tasks
|
||||
Cleanup the environment
|
||||
"""
|
||||
if hasattr(self, "server"):
|
||||
if hasattr(self.server, "servers"):
|
||||
for i, server in enumerate(self.server.servers):
|
||||
if hasattr(server, "check_task") and server.check_task:
|
||||
server.check_task.cancel()
|
||||
try:
|
||||
await server.check_task
|
||||
except Exception as e:
|
||||
self.logger.warning(
|
||||
f"Error while cancelling health check task: {e}"
|
||||
)
|
||||
server.check_task = None
|
||||
# Let the parent class handle cleanup
|
||||
await super().cleanup()
|
||||
|
||||
async def collect_trajectory(
|
||||
self, item: BLEUBERIItem
|
||||
|
|
@@ -571,12 +472,8 @@ class BLEUBERIEnv(BaseEnv):
|
|||
prompt = item["metadata"].get("prompt", "")
|
||||
source_category = item["metadata"].get("source", "unknown")
|
||||
|
||||
# Calculate individual reward metrics
|
||||
# Calculate reward metrics
|
||||
bleu_score = await self._calculate_bleu_score(response_content, references)
|
||||
rouge_score = await self._calculate_rouge_score(
|
||||
response_content, references
|
||||
)
|
||||
bertscore = await self._calculate_bertscore(response_content, references)
|
||||
|
||||
# Calculate final score using the specified reward functions
|
||||
final_score = await self._calculate_reward(response_content, references)
|
||||
|
|
@@ -587,16 +484,12 @@ class BLEUBERIEnv(BaseEnv):
|
|||
len(self.tokenizer.encode(response_content))
|
||||
)
|
||||
self.bleu_scores_buffer.append(bleu_score)
|
||||
self.rouge_scores_buffer.append(rouge_score)
|
||||
self.bertscore_buffer.append(bertscore)
|
||||
|
||||
# Maintain buffer size
|
||||
if len(self.percent_correct_buffer) > 100:
|
||||
self.percent_correct_buffer.pop(0)
|
||||
self.token_lengths_buffer.pop(0)
|
||||
self.bleu_scores_buffer.pop(0)
|
||||
self.rouge_scores_buffer.pop(0)
|
||||
self.bertscore_buffer.pop(0)
|
||||
|
||||
# Track performance by category
|
||||
if source_category not in self.category_performance:
|
||||
|
|
@@ -615,8 +508,6 @@ class BLEUBERIEnv(BaseEnv):
|
|||
"response": response_content,
|
||||
"references": references,
|
||||
"bleu_score": bleu_score,
|
||||
"rouge_score": rouge_score,
|
||||
"bertscore": bertscore,
|
||||
"final_score": final_score,
|
||||
"category": source_category,
|
||||
"is_correct": final_score > 0.5,
|
||||
|
|
@@ -648,9 +539,6 @@ class BLEUBERIEnv(BaseEnv):
|
|||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error in collect_trajectory: {e}")
|
||||
import traceback
|
||||
|
||||
self.logger.error(f"Traceback: {traceback.format_exc()}")
|
||||
return None, backlog
|
||||
|
||||
async def collect_trajectories(self, item: BLEUBERIItem) -> Tuple[
|
||||
|
|
@@ -668,14 +556,14 @@ class BLEUBERIEnv(BaseEnv):
|
|||
for _ in range(self.config.group_size):
|
||||
tasks.append(self.collect_trajectory(item))
|
||||
results = await asyncio.gather(*tasks)
|
||||
|
||||
|
||||
if any(not isinstance(result[0], dict) for result in results):
|
||||
logging.error("something wasn't a ScoredDataItem")
|
||||
raise ValueError(
|
||||
"collect_trajectory must return a ScoredDataItem or None to use the default "
|
||||
"collect_trajectories method"
|
||||
)
|
||||
|
||||
|
||||
backlog = []
|
||||
to_postprocess = ScoredDataGroup()
|
||||
to_postprocess["tokens"] = []
|
||||
|
|
@@ -687,13 +575,13 @@ class BLEUBERIEnv(BaseEnv):
|
|||
to_postprocess["group_overrides"] = {}
|
||||
to_postprocess["overrides"] = []
|
||||
to_postprocess["images"] = []
|
||||
|
||||
|
||||
self.logger.info("Processing results for BLEUBERI trajectories")
|
||||
for result in results:
|
||||
to_postprocess["tokens"].append(result[0]["tokens"])
|
||||
to_postprocess["masks"].append(result[0]["masks"])
|
||||
to_postprocess["scores"].append(result[0]["scores"])
|
||||
|
||||
|
||||
if result[0].get("advantages", None) is not None:
|
||||
to_postprocess["advantages"].append(result[0]["advantages"])
|
||||
if result[0].get("ref_logprobs", None) is not None:
|
||||
|
|
@@ -706,37 +594,42 @@ class BLEUBERIEnv(BaseEnv):
|
|||
to_postprocess["overrides"].append(result[0]["overrides"])
|
||||
if result[0].get("images", None) is not None:
|
||||
to_postprocess["images"].append(result[0]["images"])
|
||||
|
||||
|
||||
backlog.extend(result[1])
|
||||
|
||||
|
||||
# Process the data for HTML compatibility before sending to the API
|
||||
# Convert nested message structure to flat strings for HTML rendering
|
||||
if "messages" in to_postprocess and to_postprocess["messages"]:
|
||||
# Extract the assistant message content from each result
|
||||
html_compatible_messages = []
|
||||
|
||||
|
||||
for result in results:
|
||||
if "messages" in result[0] and result[0]["messages"]:
|
||||
# Find the LAST assistant message (most recent response)
|
||||
assistant_messages = [
|
||||
msg for msg in result[0]["messages"]
|
||||
msg
|
||||
for msg in result[0]["messages"]
|
||||
if msg.get("role") == "assistant"
|
||||
]
|
||||
|
||||
|
||||
if assistant_messages:
|
||||
# Get just the content of the last assistant message
|
||||
last_assistant_msg = assistant_messages[-1]
|
||||
html_compatible_messages.append(last_assistant_msg.get("content", ""))
|
||||
|
||||
html_compatible_messages.append(
|
||||
last_assistant_msg.get("content", "")
|
||||
)
|
||||
|
||||
# Replace the nested messages with flat strings
|
||||
if html_compatible_messages:
|
||||
to_postprocess["messages"] = html_compatible_messages
|
||||
self.logger.info(f"Prepared HTML-compatible format with {len(html_compatible_messages)} messages")
|
||||
|
||||
self.logger.info(
|
||||
f"Prepared HTML-compatible format with {len(html_compatible_messages)} messages"
|
||||
)
|
||||
|
||||
# The parent's handle_send_to_api method will write this to JSONL
|
||||
|
||||
|
||||
return to_postprocess, backlog
|
||||
|
||||
|
||||
async def evaluate(self):
|
||||
"""Evaluate the model on the test set."""
|
||||
self.logger.info("Starting evaluation")
|
||||
|
|
@@ -749,8 +642,6 @@ class BLEUBERIEnv(BaseEnv):
|
|||
correct_count = 0
|
||||
total_count = 0
|
||||
all_bleu_scores = []
|
||||
all_rouge_scores = []
|
||||
all_bertscore_scores = []
|
||||
all_final_scores = []
|
||||
category_results = {}
|
||||
token_lengths = []
|
||||
|
|
@@ -762,8 +653,6 @@ class BLEUBERIEnv(BaseEnv):
|
|||
"response",
|
||||
"reference",
|
||||
"bleu_score",
|
||||
"rouge_score",
|
||||
"bertscore",
|
||||
"final_score",
|
||||
"category",
|
||||
"is_correct",
|
||||
|
|
@@ -819,16 +708,10 @@ class BLEUBERIEnv(BaseEnv):
|
|||
references = example.get("references", [])
|
||||
reference_text = references[0] if references else "No reference"
|
||||
|
||||
# Calculate individual metrics
|
||||
# Calculate metrics
|
||||
bleu_score = await self._calculate_bleu_score(
|
||||
response_content, references
|
||||
)
|
||||
rouge_score = await self._calculate_rouge_score(
|
||||
response_content, references
|
||||
)
|
||||
bertscore = await self._calculate_bertscore(
|
||||
response_content, references
|
||||
)
|
||||
|
||||
# Calculate final score
|
||||
final_score = await self._calculate_reward(response_content, references)
|
||||
|
|
@@ -839,8 +722,6 @@ class BLEUBERIEnv(BaseEnv):
|
|||
|
||||
# Track scores
|
||||
all_bleu_scores.append(bleu_score)
|
||||
all_rouge_scores.append(rouge_score)
|
||||
all_bertscore_scores.append(bertscore)
|
||||
all_final_scores.append(final_score)
|
||||
|
||||
# Count as correct if score > 0.5
|
||||
|
|
@@ -860,8 +741,6 @@ class BLEUBERIEnv(BaseEnv):
|
|||
response_content,
|
||||
reference_text,
|
||||
bleu_score,
|
||||
rouge_score,
|
||||
bertscore,
|
||||
final_score,
|
||||
source_category,
|
||||
is_correct,
|
||||
|
|
@@ -872,9 +751,6 @@ class BLEUBERIEnv(BaseEnv):
|
|||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error in evaluation: {e}")
|
||||
import traceback
|
||||
|
||||
self.logger.error(f"Traceback: {traceback.format_exc()}")
|
||||
|
||||
# Calculate evaluation metrics
|
||||
accuracy = correct_count / total_count if total_count > 0 else 0
|
||||
|
|
@@ -902,14 +778,6 @@ class BLEUBERIEnv(BaseEnv):
|
|||
"eval/avg_bleu": (
|
||||
sum(all_bleu_scores) / len(all_bleu_scores) if all_bleu_scores else 0
|
||||
),
|
||||
"eval/avg_rouge": (
|
||||
sum(all_rouge_scores) / len(all_rouge_scores) if all_rouge_scores else 0
|
||||
),
|
||||
"eval/avg_bertscore": (
|
||||
sum(all_bertscore_scores) / len(all_bertscore_scores)
|
||||
if all_bertscore_scores
|
||||
else 0
|
||||
),
|
||||
"eval/avg_final_score": (
|
||||
sum(all_final_scores) / len(all_final_scores) if all_final_scores else 0
|
||||
),
|
||||
|
|
@@ -1030,27 +898,12 @@ class BLEUBERIEnv(BaseEnv):
|
|||
self.bleu_scores_buffer
|
||||
)
|
||||
|
||||
if self.rouge_scores_buffer:
|
||||
wandb_metrics["train/avg_rouge"] = sum(self.rouge_scores_buffer) / len(
|
||||
self.rouge_scores_buffer
|
||||
)
|
||||
|
||||
if self.bertscore_buffer:
|
||||
wandb_metrics["train/avg_bertscore"] = sum(self.bertscore_buffer) / len(
|
||||
self.bertscore_buffer
|
||||
)
|
||||
|
||||
# Create histograms for score distributions
|
||||
if len(self.bleu_scores_buffer) > 10:
|
||||
wandb_metrics["train/bleu_distribution"] = wandb.Histogram(
|
||||
self.bleu_scores_buffer
|
||||
)
|
||||
|
||||
if len(self.rouge_scores_buffer) > 10:
|
||||
wandb_metrics["train/rouge_distribution"] = wandb.Histogram(
|
||||
self.rouge_scores_buffer
|
||||
)
|
||||
|
||||
# Add rollout table and category performance
|
||||
wandb_metrics = await self.create_rollout_table(wandb_metrics)
|
||||
wandb_metrics = await self.create_category_performance_table(wandb_metrics)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue