remove unnecessary code, change log level

Allan Niemerg 2025-06-10 12:10:58 -05:00
parent 1a2551c812
commit 532024d01e


@@ -20,12 +20,12 @@ from pydantic import Field
from typing_extensions import TypedDict
import wandb
from atroposlib.envs.base import BaseEnv, BaseEnvConfig, ScoredDataItem, ScoredDataGroup
from atroposlib.envs.base import BaseEnv, BaseEnvConfig, ScoredDataGroup, ScoredDataItem
from atroposlib.envs.server_handling.openai_server import APIServerConfig
# Configure logging
logging.basicConfig(
level=logging.INFO,
level=logging.WARNING, # Changed from INFO to WARNING to reduce verbosity
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
handlers=[logging.StreamHandler(sys.stdout)],
)
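A minimal sketch of the quieting pattern above, with illustrative logger names: a root level of WARNING suppresses INFO/DEBUG chatter from dependencies, while a named logger can opt back in to INFO for the environment's own messages.

import logging
import sys

# Root level WARNING silences INFO/DEBUG from third-party libraries.
logging.basicConfig(
    level=logging.WARNING,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    handlers=[logging.StreamHandler(sys.stdout)],
)

# A named logger may still run at INFO for this module alone.
logger = logging.getLogger("bleuberi")  # illustrative name
logger.setLevel(logging.INFO)
logger.info("visible: this logger's effective level is INFO")
logging.getLogger("some_dependency").info("suppressed by the root WARNING level")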
@@ -33,16 +33,6 @@ logging.basicConfig(
# Load environment variables from .env file if available
load_dotenv()
# Check for OpenAI API key
if not os.environ.get("OPENAI_API_KEY"):
print(
"WARNING: OPENAI_API_KEY environment variable not found. Make sure to set it!"
)
else:
print(
f"Found OPENAI_API_KEY environment variable ({os.environ.get('OPENAI_API_KEY')[:5]}...)"
)
# Define our own Item class for the environment
class BLEUBERIItem(TypedDict):
@@ -163,46 +153,38 @@ class BLEUBERIEnv(BaseEnv):
"""Initialize configuration with OpenAI API settings."""
# Load API key from environment
api_key = os.environ.get("OPENAI_API_KEY")
if not api_key:
print("WARNING: OPENAI_API_KEY environment variable not found!")
print(
"Please set the OPENAI_API_KEY environment variable or add it to a .env file"
)
# Create environment config with all necessary settings
env_config = BLEUBERIEnvConfig(
tokenizer_name="gpt2",
group_size=2, # Reduced from 4 to minimize API calls
group_size=2,
use_wandb=True,
rollout_server_url="http://localhost:8000",
total_steps=2, # Minimal number of steps for a quick test
total_steps=2,
batch_size=-1,
steps_per_eval=1, # Evaluate after each step for quick testing
steps_per_eval=1,
max_token_length=2048,
wandb_name="bleuberi",
dataset_name="allenai/tulu-3-sft-mixture",
dataset_split="train",
reward_funcs=["bleu"],
ref_models=["gold"],
# Example limiting for quick testing (remove or set to None for full dataset)
max_train_examples=2, # Limit to just 2 training examples for minimal testing
max_test_examples=1, # Limit to 1 test example for minimal testing
# Parallelism configuration (adjust for your use case)
max_num_workers=2, # Limit number of workers for training
max_eval_workers=1, # Limit number of workers for evaluation
# Optional: Add a place to save the data
max_train_examples=2,
max_test_examples=1,
max_num_workers=2,
max_eval_workers=1,
data_path_to_save_groups="bleuberi_openai_test.jsonl",
)
# Create OpenAI server config
server_configs = [
APIServerConfig(
model_name="gpt-4.1-nano", # Or your preferred model
model_name="gpt-4.1-nano",
base_url="https://api.openai.com/v1",
api_key=api_key,
timeout=60,
num_max_requests_at_once=4, # Increased from 1 to allow parallel requests
num_requests_for_eval=4, # Increased from 1 to allow parallel evaluation
num_max_requests_at_once=4,
num_requests_for_eval=4,
),
]
@@ -225,28 +207,13 @@ class BLEUBERIEnv(BaseEnv):
api_key = os.environ.get("OPENAI_API_KEY")
if not api_key:
self.logger.warning("OPENAI_API_KEY environment variable not found!")
self.logger.warning(
"Please set the OPENAI_API_KEY environment variable or add it to a .env file"
)
else:
self.logger.info(
f"Found OPENAI_API_KEY in environment variables: {api_key[:5]}..."
)
# Update server configs with API key if needed
for server in server_configs:
if getattr(server, "server_type", "") == "openai" and not getattr(
server, "api_key", None
):
server.api_key = api_key
self.logger.info(
f"Updated server config with API key: {api_key[:5]}..."
)
# Print minimal server configuration info
for i, server in enumerate(server_configs):
if hasattr(server, "model_name"):
self.logger.info(f"Server {i+1} using model: {server.model_name}")
super().__init__(config, server_configs, slurm, testing)
self.config = config
@@ -257,18 +224,10 @@ class BLEUBERIEnv(BaseEnv):
self.test_examples = None
self.train_index = 0
# Minimal server initialization message
if hasattr(self, "server"):
self.logger.info(f"Server initialized with {len(getattr(self.server, 'servers', []))} instances")
else:
self.logger.warning("No 'server' attribute found after initialization!")
# Track training metrics
self.percent_correct_buffer = []
self.token_lengths_buffer = []
self.bleu_scores_buffer = []
self.rouge_scores_buffer = []
self.bertscore_buffer = []
self.category_performance = {} # Track performance by category
# Store rollouts for wandb visualization
@@ -298,39 +257,19 @@ class BLEUBERIEnv(BaseEnv):
self.logger.info("Setting up BLEUBERI environment")
# Load dataset
try:
from datasets import load_dataset
from datasets import load_dataset
self.dataset = load_dataset(
self.config.dataset_name,
split=self.config.dataset_split,
cache_dir=self.config.cache_dir,
streaming=self.config.streaming,
)
self.dataset = load_dataset(
self.config.dataset_name,
split=self.config.dataset_split,
cache_dir=self.config.cache_dir,
streaming=self.config.streaming,
)
if self.config.shuffle and not self.config.streaming:
self.dataset = self.dataset.shuffle(seed=self.config.seed)
if self.config.shuffle and not self.config.streaming:
self.dataset = self.dataset.shuffle(seed=self.config.seed)
self.logger.info(f"Loaded dataset with {len(self.dataset)} examples")
except Exception as e:
self.logger.error(f"Error loading dataset: {e}")
# Create a small dummy dataset for testing
from datasets import Dataset
dummy_data = []
for i in range(10):
dummy_data.append(
{
"id": i,
"messages": [
{"role": "user", "content": f"Sample prompt {i}"},
{"role": "assistant", "content": f"Sample response {i}"},
],
"source": "dummy",
}
)
self.dataset = Dataset.from_list(dummy_data)
self.logger.info(f"Created dummy dataset with {len(self.dataset)} examples")
self.logger.info(f"Loaded dataset with {len(self.dataset)} examples")
# Split into train and test (98% train, 2% test)
train_size = int(0.98 * len(self.dataset))
@@ -467,36 +406,8 @@ class BLEUBERIEnv(BaseEnv):
self, response_content: str, references: List[str]
) -> float:
"""Calculate BLEU score for a response against references using BLEUBERI implementation."""
dataset = KeywordDataset("", self.tokenizer)
kwargs = {"references": references}
scores = dataset.bleu_reward_func([response_content], **kwargs)
return scores[0] if scores else 0.0
async def _calculate_rouge_score(
self, response_content: str, references: List[str]
) -> float:
"""Calculate ROUGE score for a response against references using BLEUBERI implementation."""
dataset = KeywordDataset("", self.tokenizer)
kwargs = {"references": references}
scores = dataset.rouge_reward_func([response_content], **kwargs)
return scores[0] if scores else 0.0
async def _calculate_bertscore(
self, response_content: str, references: List[str]
) -> float:
"""Calculate BERTScore for a response against references using BLEUBERI implementation."""
dataset = KeywordDataset("", self.tokenizer)
kwargs = {"references": references}
scores = dataset.bertscore_reward_func([response_content], **kwargs)
return scores[0] if scores else 0.0
async def _calculate_bleu_rouge_f1(
self, response_content: str, references: List[str]
) -> float:
"""Calculate F1 of BLEU and ROUGE scores using BLEUBERI implementation."""
dataset = KeywordDataset("", self.tokenizer)
kwargs = {"references": references}
scores = dataset.bleu_rouge_f1_reward_func([response_content], **kwargs)
scores = self.bleuberi_dataset.bleu_reward_func([response_content], **kwargs)
return scores[0] if scores else 0.0
async def _calculate_reward(
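A minimal sketch of the consolidated scoring path, assuming KeywordDataset is imported as elsewhere in this file and that a single instance is stored as self.bleuberi_dataset at init time (the attribute name follows the call above):

class _ScoringSketch:
    def __init__(self, tokenizer):
        self.tokenizer = tokenizer
        # One shared scorer instead of a fresh KeywordDataset per call.
        self.bleuberi_dataset = KeywordDataset("", tokenizer)

    async def _calculate_bleu_score(self, response_content, references):
        kwargs = {"references": references}
        scores = self.bleuberi_dataset.bleu_reward_func([response_content], **kwargs)
        return scores[0] if scores else 0.0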
@@ -538,20 +449,10 @@ class BLEUBERIEnv(BaseEnv):
async def cleanup(self):
"""
Cleanup the environment by cancelling health check tasks
Cleanup the environment
"""
if hasattr(self, "server"):
if hasattr(self.server, "servers"):
for i, server in enumerate(self.server.servers):
if hasattr(server, "check_task") and server.check_task:
server.check_task.cancel()
try:
await server.check_task
except Exception as e:
self.logger.warning(
f"Error while cancelling health check task: {e}"
)
server.check_task = None
# Let the parent class handle cleanup
await super().cleanup()
async def collect_trajectory(
self, item: BLEUBERIItem
@@ -571,12 +472,8 @@ class BLEUBERIEnv(BaseEnv):
prompt = item["metadata"].get("prompt", "")
source_category = item["metadata"].get("source", "unknown")
# Calculate individual reward metrics
# Calculate reward metrics
bleu_score = await self._calculate_bleu_score(response_content, references)
rouge_score = await self._calculate_rouge_score(
response_content, references
)
bertscore = await self._calculate_bertscore(response_content, references)
# Calculate final score using the specified reward functions
final_score = await self._calculate_reward(response_content, references)
@@ -587,16 +484,12 @@ class BLEUBERIEnv(BaseEnv):
len(self.tokenizer.encode(response_content))
)
self.bleu_scores_buffer.append(bleu_score)
self.rouge_scores_buffer.append(rouge_score)
self.bertscore_buffer.append(bertscore)
# Maintain buffer size
if len(self.percent_correct_buffer) > 100:
self.percent_correct_buffer.pop(0)
self.token_lengths_buffer.pop(0)
self.bleu_scores_buffer.pop(0)
self.rouge_scores_buffer.pop(0)
self.bertscore_buffer.pop(0)
# Track performance by category
if source_category not in self.category_performance:
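As a design note, the pop(0)-based window above could equivalently use a bounded deque, which drops the oldest entry automatically; a short sketch with illustrative values:

from collections import deque

# maxlen=100 reproduces the 100-item window without explicit pop(0) calls.
bleu_scores_buffer = deque(maxlen=100)
for score in (0.2, 0.5, 0.9):
    bleu_scores_buffer.append(score)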
@@ -615,8 +508,6 @@ class BLEUBERIEnv(BaseEnv):
"response": response_content,
"references": references,
"bleu_score": bleu_score,
"rouge_score": rouge_score,
"bertscore": bertscore,
"final_score": final_score,
"category": source_category,
"is_correct": final_score > 0.5,
@@ -648,9 +539,6 @@ class BLEUBERIEnv(BaseEnv):
except Exception as e:
self.logger.error(f"Error in collect_trajectory: {e}")
import traceback
self.logger.error(f"Traceback: {traceback.format_exc()}")
return None, backlog
async def collect_trajectories(self, item: BLEUBERIItem) -> Tuple[
@@ -668,14 +556,14 @@ class BLEUBERIEnv(BaseEnv):
for _ in range(self.config.group_size):
tasks.append(self.collect_trajectory(item))
results = await asyncio.gather(*tasks)
if any(not isinstance(result[0], dict) for result in results):
logging.error("something wasn't a ScoredDataItem")
raise ValueError(
"collect_trajectory must return a ScoredDataItem or None to use the default "
"collect_trajectories method"
)
backlog = []
to_postprocess = ScoredDataGroup()
to_postprocess["tokens"] = []
@@ -687,13 +575,13 @@ class BLEUBERIEnv(BaseEnv):
to_postprocess["group_overrides"] = {}
to_postprocess["overrides"] = []
to_postprocess["images"] = []
self.logger.info("Processing results for BLEUBERI trajectories")
for result in results:
to_postprocess["tokens"].append(result[0]["tokens"])
to_postprocess["masks"].append(result[0]["masks"])
to_postprocess["scores"].append(result[0]["scores"])
if result[0].get("advantages", None) is not None:
to_postprocess["advantages"].append(result[0]["advantages"])
if result[0].get("ref_logprobs", None) is not None:
@@ -706,37 +594,42 @@ class BLEUBERIEnv(BaseEnv):
to_postprocess["overrides"].append(result[0]["overrides"])
if result[0].get("images", None) is not None:
to_postprocess["images"].append(result[0]["images"])
backlog.extend(result[1])
# Process the data for HTML compatibility before sending to the API
# Convert nested message structure to flat strings for HTML rendering
if "messages" in to_postprocess and to_postprocess["messages"]:
# Extract the assistant message content from each result
html_compatible_messages = []
for result in results:
if "messages" in result[0] and result[0]["messages"]:
# Find the LAST assistant message (most recent response)
assistant_messages = [
msg for msg in result[0]["messages"]
msg
for msg in result[0]["messages"]
if msg.get("role") == "assistant"
]
if assistant_messages:
# Get just the content of the last assistant message
last_assistant_msg = assistant_messages[-1]
html_compatible_messages.append(last_assistant_msg.get("content", ""))
html_compatible_messages.append(
last_assistant_msg.get("content", "")
)
# Replace the nested messages with flat strings
if html_compatible_messages:
to_postprocess["messages"] = html_compatible_messages
self.logger.info(f"Prepared HTML-compatible format with {len(html_compatible_messages)} messages")
self.logger.info(
f"Prepared HTML-compatible format with {len(html_compatible_messages)} messages"
)
# The parent's handle_send_to_api method will write this to JSONL
return to_postprocess, backlog
async def evaluate(self):
"""Evaluate the model on the test set."""
self.logger.info("Starting evaluation")
@@ -749,8 +642,6 @@ class BLEUBERIEnv(BaseEnv):
correct_count = 0
total_count = 0
all_bleu_scores = []
all_rouge_scores = []
all_bertscore_scores = []
all_final_scores = []
category_results = {}
token_lengths = []
@@ -762,8 +653,6 @@ class BLEUBERIEnv(BaseEnv):
"response",
"reference",
"bleu_score",
"rouge_score",
"bertscore",
"final_score",
"category",
"is_correct",
@@ -819,16 +708,10 @@ class BLEUBERIEnv(BaseEnv):
references = example.get("references", [])
reference_text = references[0] if references else "No reference"
# Calculate individual metrics
# Calculate metrics
bleu_score = await self._calculate_bleu_score(
response_content, references
)
rouge_score = await self._calculate_rouge_score(
response_content, references
)
bertscore = await self._calculate_bertscore(
response_content, references
)
# Calculate final score
final_score = await self._calculate_reward(response_content, references)
@@ -839,8 +722,6 @@ class BLEUBERIEnv(BaseEnv):
# Track scores
all_bleu_scores.append(bleu_score)
all_rouge_scores.append(rouge_score)
all_bertscore_scores.append(bertscore)
all_final_scores.append(final_score)
# Count as correct if score > 0.5
@@ -860,8 +741,6 @@ class BLEUBERIEnv(BaseEnv):
response_content,
reference_text,
bleu_score,
rouge_score,
bertscore,
final_score,
source_category,
is_correct,
@@ -872,9 +751,6 @@ class BLEUBERIEnv(BaseEnv):
except Exception as e:
self.logger.error(f"Error in evaluation: {e}")
import traceback
self.logger.error(f"Traceback: {traceback.format_exc()}")
# Calculate evaluation metrics
accuracy = correct_count / total_count if total_count > 0 else 0
@@ -902,14 +778,6 @@ class BLEUBERIEnv(BaseEnv):
"eval/avg_bleu": (
sum(all_bleu_scores) / len(all_bleu_scores) if all_bleu_scores else 0
),
"eval/avg_rouge": (
sum(all_rouge_scores) / len(all_rouge_scores) if all_rouge_scores else 0
),
"eval/avg_bertscore": (
sum(all_bertscore_scores) / len(all_bertscore_scores)
if all_bertscore_scores
else 0
),
"eval/avg_final_score": (
sum(all_final_scores) / len(all_final_scores) if all_final_scores else 0
),
@@ -1030,27 +898,12 @@ class BLEUBERIEnv(BaseEnv):
self.bleu_scores_buffer
)
if self.rouge_scores_buffer:
wandb_metrics["train/avg_rouge"] = sum(self.rouge_scores_buffer) / len(
self.rouge_scores_buffer
)
if self.bertscore_buffer:
wandb_metrics["train/avg_bertscore"] = sum(self.bertscore_buffer) / len(
self.bertscore_buffer
)
# Create histograms for score distributions
if len(self.bleu_scores_buffer) > 10:
wandb_metrics["train/bleu_distribution"] = wandb.Histogram(
self.bleu_scores_buffer
)
if len(self.rouge_scores_buffer) > 10:
wandb_metrics["train/rouge_distribution"] = wandb.Histogram(
self.rouge_scores_buffer
)
# Add rollout table and category performance
wandb_metrics = await self.create_rollout_table(wandb_metrics)
wandb_metrics = await self.create_category_performance_table(wandb_metrics)
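A minimal sketch of the histogram logging retained above; the project name and scores are illustrative, and offline mode avoids needing credentials:

import wandb

bleu_scores = [0.12, 0.40, 0.35, 0.81]  # illustrative buffer contents
run = wandb.init(project="bleuberi", mode="offline")
run.log({
    "train/avg_bleu": sum(bleu_scores) / len(bleu_scores),
    "train/bleu_distribution": wandb.Histogram(bleu_scores),
})
run.finish()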