mirror of https://github.com/NousResearch/atropos.git
synced 2026-04-19 12:57:58 +00:00
Add cycling curriculum, difficulty threshold, update datadumps
This commit is contained in:
parent ab9e90a173
commit ad1bdf7f80

2 changed files with 737 additions and 60 deletions
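A note on the design before the diff: the cycling curriculum keeps a queue of unsolved items; each fetch pops the front item, and after its rollout group is scored the item either retires or rejoins at the back. A minimal sketch of that queue discipline (names mirror the new environment fields; the 0.75/0.9 thresholds follow the defaults added in this commit):

    # Sketch only: the solved/unsolved cycle implemented by get_next_item()
    # and _handle_item_result() in this commit.
    from typing import List

    def curriculum_step(active_queue: List[dict], solved: List[dict],
                        group_average_score: float, max_avg: float = 0.75) -> None:
        item = active_queue.pop(0)  # next item in circulation
        if group_average_score > max_avg or group_average_score >= 0.9:
            solved.append(item)  # solved: retired from circulation
        else:
            active_queue.append(item)  # unsolved: back of the queue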
@@ -1,14 +1,16 @@
 import json
 import os
 import random
 import re
+import uuid
 from typing import Dict, List, Optional, Tuple
+
+import wandb
 from datasets import Dataset, load_dataset
 from langdetect import LangDetectException, detect
 from pydantic import Field
 from tqdm.asyncio import tqdm_asyncio

-import wandb
 from atroposlib.envs.base import (
     APIServerConfig,
     BaseEnv,
@@ -36,6 +38,31 @@ class IFConfig(BaseEnvConfig):
     test_set_ratio: float = Field(
         0.05, description="The ratio of the selected dataset for testing"
     )
+    dump_rollouts: bool = Field(
+        False, description="Whether to dump successful rollouts to JSONL files"
+    )
+    dump_failed_rollouts: bool = Field(
+        False,
+        description="Whether to dump failed rollouts (all 0 scores) to JSONL files for debugging",
+    )
+    rollout_save_score_threshold: float = Field(
+        0.7, description="Minimum score threshold for saving rollouts to data dumps"
+    )
+    max_group_average_for_training: float = Field(
+        0.75,
+        description="Maximum group average score to use for training (skip groups that are too easy)",
+    )
+    dataset_shuffle_seed: int = Field(
+        42, description="Seed for shuffling the dataset during setup"
+    )
+    resume_from_unsolved_dataset: Optional[str] = Field(
+        None,
+        description="Path to a remaining_unsolved.jsonl file to resume training from specific unsolved items",
+    )
+    suppress_base_env_logs: bool = Field(
+        default=True,
+        description="Suppress verbose base environment logs (like status dict updates).",
+    )


 class InstructionFollowingEnv(BaseEnv):
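For reference, this is roughly how the new knobs are set when constructing the config; a hedged sketch, assuming any required pre-existing BaseEnvConfig fields are filled in elsewhere:

    # Illustrative only: the curriculum/data-dump fields added in this commit.
    config = IFConfig(
        test_set_ratio=0.05,
        dump_rollouts=True,
        dump_failed_rollouts=False,
        rollout_save_score_threshold=0.7,
        max_group_average_for_training=0.75,
        dataset_shuffle_seed=42,
        resume_from_unsolved_dataset=None,  # or a path to remaining_unsolved.jsonl
        suppress_base_env_logs=True,
    )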
@@ -53,6 +80,42 @@ class InstructionFollowingEnv(BaseEnv):
         self.eval_metrics = list()
         self.rollouts_for_wandb = []
+
+        # Data dumping infrastructure
+        self.rollouts_to_save_buffer = []
+        self.failed_rollouts_to_save_buffer = []
+        self.run_uuid = str(uuid.uuid4())[:8]
+        self.save_file_batch_num = 0
+        self.failed_save_file_batch_num = 0
+
+        # Adaptive curriculum: cycling queue for unsolved items
+        self.active_train_queue = []  # Items currently in circulation
+        self.solved_items = []  # Items that have been solved (removed from circulation)
+        self.item_attempt_counts = (
+            {}
+        )  # Track how many times each item has been attempted
+
+        # Create data dumps directory
+        self.datadumps_dir = os.path.join(
+            os.path.dirname(os.path.abspath(__file__)), "data_dumps"
+        )
+
+        # Create datasets directory for curriculum state dumps
+        self.datasets_dir = os.path.join(
+            os.path.dirname(os.path.abspath(__file__)), "datasets"
+        )
+
+        # Validate configuration for potential conflicts
+        self._validate_config()
+
+        # Configure logging suppression
+        if self.config.suppress_base_env_logs:
+            import logging
+
+            # Suppress specific loggers that are too verbose
+            logging.getLogger("atroposlib.envs.base").setLevel(logging.WARNING)
+            logging.getLogger("httpx").setLevel(logging.WARNING)
+            logging.getLogger("httpcore").setLevel(logging.WARNING)

     @classmethod
     def config_init(
         self,
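The truncated UUID and batch counters initialized above feed the dump filenames used further down in this diff; a small illustration of the resulting naming scheme (the pattern is copied from _save_rollouts_to_jsonl below):

    import uuid

    run_uuid = str(uuid.uuid4())[:8]  # e.g. "3f9c1a2b"
    save_file_batch_num = 0
    print(f"instruction_following_rollouts_{run_uuid}_{save_file_batch_num:04d}.jsonl")
    # -> instruction_following_rollouts_3f9c1a2b_0000.jsonl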
@@ -74,6 +137,13 @@ class InstructionFollowingEnv(BaseEnv):
             dataset_name="allenai/RLVR-IFeval",  # Default dataset
             dataset_config_name=None,  # RLVR-IFeval doesn't have a specific config name, uses 'default'
             test_set_ratio=0.05,  # The ratio of the selected dataset used for testing
+            dump_rollouts=False,  # Enable data dumping if needed
+            dump_failed_rollouts=False,  # Enable failed rollout dumping for debugging
+            rollout_save_score_threshold=0.7,  # Save rollouts with score >= 0.7
+            max_group_average_for_training=0.75,  # Skip groups that are too easy for training
+            dataset_shuffle_seed=42,  # Seed for dataset shuffling
+            resume_from_unsolved_dataset=None,  # Path to resume from unsolved items
+            suppress_base_env_logs=True,  # Suppress verbose base environment logs
         )
         # Server configurations can be similar to SingleToolCallingEnv or adjusted
         server_configs = [
@@ -111,6 +181,31 @@ class InstructionFollowingEnv(BaseEnv):
         except ZeroDivisionError:
             pass  # Buffer might be empty

+        # Add adaptive curriculum metrics
+        total_items = len(self.active_train_queue) + len(self.solved_items)
+        if total_items > 0:
+            wandb_metrics["curriculum/active_items"] = len(self.active_train_queue)
+            wandb_metrics["curriculum/solved_items"] = len(self.solved_items)
+            wandb_metrics["curriculum/percent_solved"] = (
+                len(self.solved_items) / total_items
+            )
+            wandb_metrics["curriculum/total_items"] = total_items
+
+            # Average attempt count for items still in circulation
+            if self.item_attempt_counts:
+                active_attempts = [
+                    count
+                    for item_id, count in self.item_attempt_counts.items()
+                    if any(
+                        f"{item['func_name']}_{hash(str(item)) % 100000}" == item_id
+                        for item in self.active_train_queue
+                    )
+                ]
+                if active_attempts:
+                    wandb_metrics["curriculum/avg_attempts_active"] = sum(
+                        active_attempts
+                    ) / len(active_attempts)
+
         self.percent_correct_buffer = list()
         for item in self.eval_metrics:
             wandb_metrics[item[0]] = item[1]
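One caveat worth flagging on the metric above: the ids are built with hash(str(item)), and Python salts string hashes per process unless PYTHONHASHSEED is fixed, so these ids (and the attempt counts keyed on them, including those restored from a resume file later in this diff) will not line up across restarts. A hedged sketch of a process-stable alternative; stable_item_id is a hypothetical helper, not part of this diff:

    # Restart-stable item id via hashlib instead of the salted built-in hash().
    import hashlib

    def stable_item_id(item: dict) -> str:
        digest = hashlib.sha1(str(item).encode("utf-8")).hexdigest()
        return f"{item['func_name']}_{int(digest, 16) % 100000}"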
@@ -278,7 +373,7 @@ class InstructionFollowingEnv(BaseEnv):
                 f"due to previous errors."
             )

-        full_dataset = full_dataset.shuffle(seed=42)
+        full_dataset = full_dataset.shuffle(seed=self.config.dataset_shuffle_seed)

         actual_test_size = self.config.test_set_ratio  # Read from config
         num_items = len(full_dataset)
@@ -328,9 +423,46 @@ class InstructionFollowingEnv(BaseEnv):
             self.test = self.train

         self.iter = 0

+        # Initialize the adaptive curriculum queue
+        if self.config.resume_from_unsolved_dataset:
+            print(
+                f"🔄 Resume mode: Loading unsolved items from {self.config.resume_from_unsolved_dataset}"
+            )
+            print(
+                f"   Note: This will override the dataset_name '{self.config.dataset_name}' for training items"
+            )
+            print(f"   Test set will still use items from '{self.config.dataset_name}'")
+            await self._load_from_unsolved_dataset()
+        else:
+            # Initialize with all training items
+            self.active_train_queue = list(self.train)
+            self.solved_items = []
+            self.item_attempt_counts = {}
+
         print(
             f"Dataset setup complete. Train size: {len(self.train)}, Test size: {len(self.test)}"
         )
+        print(
+            f"Adaptive curriculum initialized with {len(self.active_train_queue)} items in active queue"
+        )

+    def _validate_config(self):
+        """Validate configuration for potential conflicts and warn user."""
+        if self.config.resume_from_unsolved_dataset and self.config.dataset_name:
+            print("⚠️ Configuration Notice:")
+            print(
+                f"   Both 'dataset_name' ({self.config.dataset_name}) and 'resume_from_unsolved_dataset' are set"
+            )
+            print("   Behavior:")
+            print(
+                "   - Training items: Will come from the resume file (overrides dataset_name)"
+            )
+            print("   - Test/eval items: Will come from dataset_name")
+            print(
+                "   - This is useful for resuming training while keeping consistent evaluation"
+            )
+            print()
+
     async def _get_score_from_verifier(
         self, model_response_text: str, func_name: str, args: Dict
@@ -448,10 +580,6 @@ class InstructionFollowingEnv(BaseEnv):
         func_name = test_item["func_name"]
         args_for_verifier = test_item["args"]

-        print(
-            f"DEBUG: Entering rollout_and_score_eval. Prompt: {instruction_prompt_text[:200]}..."
-        )  # DEBUG
-
         messages = [{"role": "system", "content": system_prompt}]
         messages.append({"role": "user", "content": instruction_prompt_text})

@@ -459,9 +587,6 @@ class InstructionFollowingEnv(BaseEnv):
             messages, add_generation_prompt=True, tokenize=False
         )

-        print(
-            f"DEBUG: Calling self.server.completion in rollout_and_score_eval. Prompt: {prompt_str[:200]}..."
-        )  # DEBUG
         completion = await self.server.completion(
             prompt=prompt_str,
             n=1,
@@ -469,7 +594,6 @@ class InstructionFollowingEnv(BaseEnv):
             temperature=0.2,  # Temperature for eval, can be 0 for deterministic
             split="eval",
         )
-        print("DEBUG: Received completion in rollout_and_score_eval.")  # DEBUG

         model_response_text = completion.choices[0].text
         score_value = await self._get_score_from_verifier(
@@ -487,12 +611,12 @@ class InstructionFollowingEnv(BaseEnv):
             self.eval_metrics.append(("eval/percent_correct", 0.0))
             return

-        print(f"DEBUG: Starting evaluation. Test set size: {len(self.test)}")  # DEBUG
+        print(f"Starting evaluation on {len(self.test)} items...")
         eval_tasks = []
         for test_item_dict in self.test:  # self.test contains dicts after setup
             eval_tasks.append(self.rollout_and_score_eval(test_item_dict))

-        scores = await tqdm_asyncio.gather(*eval_tasks)
+        scores = await tqdm_asyncio.gather(*eval_tasks, desc="Evaluating")

         if not scores:  # If gather returns empty list
             percent_correct = 0.0
@@ -500,14 +624,13 @@ class InstructionFollowingEnv(BaseEnv):
             percent_correct = sum(scores) / len(scores)

         self.eval_metrics.append(("eval/percent_correct", percent_correct))
-        print(f"Evaluation percent correct: {percent_correct}")
+        print(f"Evaluation finished. Percent correct: {percent_correct:.4f}")

     async def collect_trajectories(
         self, item: Item
     ) -> Tuple[Optional[ScoredDataGroup], List]:
         # item = (prompt_messages_tuple, answer_info_dict)
         # answer_info_dict = {"func_name": ..., "args": ...}
-        print(f"DEBUG: Entering collect_trajectories. Item: {str(item)}")  # DEBUG
         prompt_messages_list = [dict(msg_fset) for msg_fset in item[0]]
         answer_info = item[1]

@@ -515,9 +638,6 @@ class InstructionFollowingEnv(BaseEnv):
             prompt_messages_list, add_generation_prompt=True, tokenize=False
         )

-        print(
-            f"DEBUG: Calling self.server.completion in collect_trajectories. Prompt: {prompt_str[:200]}..."
-        )  # DEBUG
         try:
             completions = await self.server.completion(
                 prompt=prompt_str,
@@ -525,14 +645,8 @@ class InstructionFollowingEnv(BaseEnv):
                 max_tokens=self.config.max_token_length,
                 temperature=0.8,  # Temperature for diverse responses during training rollouts
             )
-            print(
-                f"DEBUG: Received {len(completions.choices)} completions in collect_trajectories."
-            )  # DEBUG
         except Exception as e:
-            print(
-                f"ERROR: Exception during self.server.completion in collect_trajectories: {e}"
-            )  # DEBUG
-            # Depending on the desired behavior, you might want to return None or raise the exception
+            print(f"ERROR: Exception during completion generation: {e}")
             return None, []

         to_score_list = []
@@ -546,23 +660,73 @@ class InstructionFollowingEnv(BaseEnv):
         if not to_score_list:
             return None, []

-        print(
-            f"DEBUG: Scoring {len(to_score_list)} trajectories in collect_trajectories."
-        )  # DEBUG
         scored_data = await self.score(to_score_list)

+        # Handle adaptive curriculum: decide whether to keep item in circulation
+        if scored_data and scored_data.get("scores"):
+            group_average_score = sum(scored_data["scores"]) / len(
+                scored_data["scores"]
+            )
+            self._handle_item_result(item, group_average_score)
+        elif scored_data is None:
+            # If scored_data is None, it might be because the group was skipped for being too easy
+            # We need to calculate the scores ourselves to handle the item properly
+            temp_scores = []
+            for trajectory_messages, answer_info in to_score_list:
+                model_response_text = trajectory_messages[-1]["content"]
+                func_name = answer_info["func_name"]
+                args_for_verifier = answer_info["args"]
+
+                # Get score (1.0 for correct, 0.0 for incorrect from verifier)
+                score_value = await self._get_score_from_verifier(
+                    model_response_text, func_name, args_for_verifier
+                )
+                reward = 1.0 if score_value == 1.0 else 0
+                temp_scores.append(reward)
+
+            if temp_scores:
+                group_average_score = sum(temp_scores) / len(temp_scores)
+                self._handle_item_result(item, group_average_score)
+
         to_backlog = []  # Backlog not currently used but part of signature

-        print(
-            f"DEBUG: Exiting collect_trajectories. Scored data: {bool(scored_data)}"
-        )  # DEBUG
         return scored_data, to_backlog

     def save_checkpoint(self, step, data=None):
         if data is None:
             data = {}
         data["iter"] = self.iter
+        data["save_file_batch_num"] = self.save_file_batch_num
+        data["failed_save_file_batch_num"] = self.failed_save_file_batch_num
+        # Save adaptive curriculum state
+        data["active_train_queue"] = self.active_train_queue
+        data["solved_items"] = self.solved_items
+        data["item_attempt_counts"] = self.item_attempt_counts
         super().save_checkpoint(step, data)

+    async def close(self):
+        """Save any remaining rollouts and curriculum state before closing."""
+        if self.config.dump_rollouts and self.rollouts_to_save_buffer:
+            print(
+                f"Saving {len(self.rollouts_to_save_buffer)} remaining rollouts before closing..."
+            )
+            await self._save_rollouts_to_jsonl()
+
+        if self.config.dump_failed_rollouts and self.failed_rollouts_to_save_buffer:
+            print(
+                f"Saving {len(self.failed_rollouts_to_save_buffer)} remaining failed rollouts before closing..."
+            )
+            await self._save_failed_rollouts_to_jsonl()
+
+        # Save final curriculum state
+        if self.active_train_queue:
+            print(
+                f"Saving final curriculum state with {len(self.active_train_queue)} unsolved items..."
+            )
+            await self._dump_active_queue_dataset()
+
+        await super().close()
+
     async def score(
         self, rollout_group_data: List[Tuple[tuple, Dict]]
     ) -> Optional[ScoredDataGroup]:
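The _handle_item_result call above applies a retirement rule defined near the end of this diff; isolated, the predicate reduces to the following (0.75 is the config default, 0.9 is hard-coded):

    # Sketch of the "solved" predicate: retire items the difficulty filter
    # would skip as too easy, or items the policy has effectively mastered.
    def is_solved(group_average_score: float,
                  max_group_average_for_training: float = 0.75) -> bool:
        return (group_average_score > max_group_average_for_training
                or group_average_score >= 0.9)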
@@ -583,6 +747,10 @@ class InstructionFollowingEnv(BaseEnv):

         random.shuffle(rollout_group_data)  # Shuffle to avoid bias

+        # Data dumping: collect rollouts for saving (group format)
+        rollouts_for_this_group = []
+        failed_rollouts_for_this_group = []
+
         for trajectory_item in rollout_group_data:
             full_trajectory_messages = trajectory_item[0]
             answer_info = trajectory_item[1]  # {"func_name": ..., "args": ...}
@@ -599,6 +767,24 @@ class InstructionFollowingEnv(BaseEnv):
             # Map to reward: 1.0 for correct, 0 for incorrect
             reward = 1.0 if score_value == 1.0 else 0

+            # Prepare structured conversation for data dumping
+            conversation = [dict(msg) for msg in full_trajectory_messages]
+
+            # Create rollout dict for this specific rollout in the group
+            rollout_dict = {
+                "conversation": conversation,
+                "score": reward,
+            }
+
+            # Collect rollouts for this group based on score and config
+            if (
+                self.config.dump_rollouts
+                and reward >= self.config.rollout_save_score_threshold
+            ):
+                rollouts_for_this_group.append(rollout_dict)
+            elif self.config.dump_failed_rollouts and reward == 0:
+                failed_rollouts_for_this_group.append(rollout_dict)
+
             # Tokenize the conversation for PPO training
             # Ensure full_trajectory_messages is a list of dicts
             list_of_dicts_trajectory = [dict(msg) for msg in full_trajectory_messages]
@@ -623,6 +809,128 @@ class InstructionFollowingEnv(BaseEnv):
         if not scores_container["tokens"]:  # No valid items collected
             return None

+        # Calculate group average score for difficulty filtering and logging
+        current_scores = scores_container.get("scores", [])
+        if current_scores:
+            average_score = sum(current_scores) / len(current_scores)
+            # Get task info from the first rollout's answer_info
+            answer_info = rollout_group_data[0][1] if rollout_group_data else {}
+            func_name = answer_info.get("func_name", "unknown_task")
+
+            # Check if group is too easy for training (but still allow data dumping)
+            if average_score > self.config.max_group_average_for_training:
+                print(
+                    f"Task: {func_name} | Group average score: {average_score:.4f} (SKIPPED - too easy for training, threshold: {self.config.max_group_average_for_training})"  # noqa
+                )
+
+                # Still handle data dumping for groups that are too easy for training
+                # but might be useful for analysis
+                if (
+                    rollouts_for_this_group
+                    and average_score
+                    <= self.config.max_group_average_for_training + 0.1
+                ):  # Small buffer for data collection
+                    # Extract item info for the group - get from first rollout's answer_info
+                    answer_info = rollout_group_data[0][1]
+                    item_id = f"allenai_RLVR-IFeval_train_item_{answer_info.get('func_name', 'unknown')}_{hash(str(answer_info)) % 100000}"  # noqa
+
+                    group_data_to_save = {
+                        "item_id": item_id,
+                        "rollouts": rollouts_for_this_group,
+                        "constraint_details": answer_info,  # Store group-level metadata
+                        "group_average_score": average_score,  # Add group average for analysis
+                        "skipped_for_training": True,  # Mark as skipped for training
+                    }
+                    self.rollouts_to_save_buffer.append(group_data_to_save)
+
+                if failed_rollouts_for_this_group:
+                    # Extract item info for the failed group
+                    answer_info = rollout_group_data[0][1]
+                    item_id = f"allenai_RLVR-IFeval_train_item_{answer_info.get('func_name', 'unknown')}_{hash(str(answer_info)) % 100000}"  # noqa
+
+                    failed_group_data_to_save = {
+                        "item_id": item_id,
+                        "rollouts": failed_rollouts_for_this_group,
+                        "constraint_details": answer_info,  # Store group-level metadata
+                        "group_average_score": average_score,  # Add group average for analysis
+                        "skipped_for_training": True,  # Mark as skipped for training
+                    }
+                    self.failed_rollouts_to_save_buffer.append(
+                        failed_group_data_to_save
+                    )
+
+                # Save rollouts if buffer is getting large (batch processing)
+                if (
+                    self.config.dump_rollouts
+                    and len(self.rollouts_to_save_buffer) >= 100
+                ):
+                    await self._save_rollouts_to_jsonl()
+                if (
+                    self.config.dump_failed_rollouts
+                    and len(self.failed_rollouts_to_save_buffer) >= 50
+                ):
+                    await self._save_failed_rollouts_to_jsonl()
+
+                return None  # Skip this group for training
+
+            log_message = (
+                f"Task: {func_name} | Group average score: {average_score:.4f}"
+            )
+            if all(s >= 0.5 for s in current_scores):
+                print(f"{log_message} (All correct in this group!)")
+            elif all(s == 0.0 for s in current_scores):
+                print(f"{log_message} (All failed - format/constraint violations!)")
+            elif all(s < 0.5 for s in current_scores):
+                print(f"{log_message} (All incorrect but some partial credit!)")
+            else:
+                print(log_message)
+
+        # Create group data structure and add to buffers for data dumping (for training groups)
+        if rollouts_for_this_group:
+            # Extract item info for the group - get from first rollout's answer_info
+            answer_info = rollout_group_data[0][1]
+            item_id = f"allenai_RLVR-IFeval_train_item_{answer_info.get('func_name', 'unknown')}_{hash(str(answer_info)) % 100000}"  # noqa
+
+            group_data_to_save = {
+                "item_id": item_id,
+                "rollouts": rollouts_for_this_group,
+                "constraint_details": answer_info,  # Store group-level metadata
+                "group_average_score": (
+                    current_scores[0]
+                    if len(current_scores) == 1
+                    else sum(current_scores) / len(current_scores)
+                ),  # Add group average for analysis
+                "skipped_for_training": False,  # Mark as used for training
+            }
+            self.rollouts_to_save_buffer.append(group_data_to_save)
+
+        if failed_rollouts_for_this_group:
+            # Extract item info for the failed group
+            answer_info = rollout_group_data[0][1]
+            item_id = f"allenai_RLVR-IFeval_train_item_{answer_info.get('func_name', 'unknown')}_{hash(str(answer_info)) % 100000}"  # noqa
+
+            failed_group_data_to_save = {
+                "item_id": item_id,
+                "rollouts": failed_rollouts_for_this_group,
+                "constraint_details": answer_info,  # Store group-level metadata
+                "group_average_score": (
+                    current_scores[0]
+                    if len(current_scores) == 1
+                    else sum(current_scores) / len(current_scores)
+                ),  # Add group average for analysis
+                "skipped_for_training": False,  # Mark as used for training
+            }
+            self.failed_rollouts_to_save_buffer.append(failed_group_data_to_save)
+
+        # Save rollouts if buffer is getting large (batch processing)
+        if self.config.dump_rollouts and len(self.rollouts_to_save_buffer) >= 100:
+            await self._save_rollouts_to_jsonl()
+        if (
+            self.config.dump_failed_rollouts
+            and len(self.failed_rollouts_to_save_buffer) >= 50
+        ):
+            await self._save_failed_rollouts_to_jsonl()
+
         # Record success rate for logging (based on positive rewards)
         for rwd in scores_container["scores"]:
             self.percent_correct_buffer.append(
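Putting the dicts above together, each line of a rollout dump is one JSON object per group, shaped roughly like this (values are illustrative; the hash suffix in item_id varies by process, as noted earlier):

    # Illustrative shape of one line in a rollouts JSONL dump.
    example_group_record = {
        "item_id": "allenai_RLVR-IFeval_train_item_verify_keywords_48213",
        "rollouts": [
            {
                "conversation": [
                    {"role": "system", "content": "..."},
                    {"role": "user", "content": "..."},
                    {"role": "assistant", "content": "..."},
                ],
                "score": 1.0,
            },
        ],
        "constraint_details": {"func_name": "verify_keywords",
                               "args": {"keyword_list": ["hello"]}},
        "group_average_score": 0.625,
        "skipped_for_training": False,
    }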
@@ -665,26 +973,48 @@ class InstructionFollowingEnv(BaseEnv):
         return scores_container

     async def get_next_item(self) -> Item:
-        # Fetches the next preprocessed item from the training set
-        if not self.train or len(self.train) == 0:
-            # This case should be handled by setup, but as a safeguard:
-            print("Error: Training data is empty in get_next_item.")
-            # Return a dummy item to prevent crashes, though this indicates a setup issue
-            dummy_prompt_messages = (
-                frozenset({"role": "system", "content": system_prompt}.items()),
-                frozenset(
-                    {"role": "user", "content": "Dummy instruction: say hello."}.items()
-                ),
-            )
-            dummy_answer_info = {
-                "func_name": "verify_keywords",
-                "args": {"keyword_list": ["hello"]},
-            }
-            return (dummy_prompt_messages, dummy_answer_info)
+        # Fetches the next item from the adaptive curriculum queue
+        if not self.active_train_queue:
+            # If active queue is empty, check if we have any items left
+            if not self.solved_items:
+                print("Error: No training data available in get_next_item.")
+                # Return a dummy item to prevent crashes
+                dummy_prompt_messages = (
+                    frozenset({"role": "system", "content": system_prompt}.items()),
+                    frozenset(
+                        {
+                            "role": "user",
+                            "content": "Dummy instruction: say hello.",
+                        }.items()
+                    ),
+                )
+                dummy_answer_info = {
+                    "func_name": "verify_keywords",
+                    "args": {"keyword_list": ["hello"]},
+                }
+                return (dummy_prompt_messages, dummy_answer_info)
+            else:
+                # All items have been solved! Reset the queue with solved items for continued training
+                print(
+                    f"🎉 All {len(self.solved_items)} items have been solved! Resetting queue for continued training..."
+                )
+                self.active_train_queue = list(self.solved_items)
+                self.solved_items = []
+                # Reset attempt counts for the new cycle
+                self.item_attempt_counts = {}

-        raw_item = self.train[self.iter % len(self.train)]  # raw_item is a dict
+        # Get the next item from the front of the active queue
+        raw_item = self.active_train_queue.pop(0)
         self.iter += 1

+        # Create a unique identifier for this item for tracking
+        item_id = f"{raw_item['func_name']}_{hash(str(raw_item)) % 100000}"
+
+        # Track attempt count
+        if item_id not in self.item_attempt_counts:
+            self.item_attempt_counts[item_id] = 0
+        self.item_attempt_counts[item_id] += 1
+
         instruction_prompt_text = raw_item["prompt"]

         # Construct messages for the LLM (prompt tuple part of Item)
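A small design note on the queue operations above: active_train_queue.pop(0) on a list shifts every remaining element, so each fetch is O(n). For large datasets, collections.deque is the idiomatic drop-in; a hedged sketch:

    # Same front-pop/back-append discipline with O(1) operations per step.
    from collections import deque

    active_train_queue = deque(["item_a", "item_b", "item_c"])
    raw_item = active_train_queue.popleft()  # instead of list.pop(0)
    active_train_queue.append(raw_item)      # unsolved items rejoin at the back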
@@ -698,6 +1028,10 @@ class InstructionFollowingEnv(BaseEnv):
         answer_info = {
             "func_name": raw_item["func_name"],
             "args": raw_item["args"],
+            # Add item tracking info
+            "item_id": item_id,
+            "raw_item": raw_item,  # Store the full item for queue management
+            "attempt_count": self.item_attempt_counts[item_id],
             # Optionally include other info for logging/debugging if needed from raw_item
             "original_constraints_for_logging": raw_item.get(
                 "original_constraints", ""
@@ -707,8 +1041,201 @@ class InstructionFollowingEnv(BaseEnv):
             ),
         }

+        # Dump active queue every 100 iterations for resumability
+        if self.iter % 100 == 0 and self.iter > 0:
+            await self._dump_active_queue_dataset()
+
         return (prompt_messages_tuple, answer_info)

+    async def _dump_active_queue_dataset(self):
+        """
+        Dumps the current active queue to a JSONL file so training can be resumed
+        from the unsolved items if the environment is shut down.
+        """
+        if not self.active_train_queue:
+            print("No active items to dump - all items have been solved!")
+            return
+
+        try:
+            if not os.path.exists(self.datasets_dir):
+                os.makedirs(self.datasets_dir)
+                print(f"Created datasets directory: {self.datasets_dir}")
+        except Exception as e:
+            print(f"Error creating datasets directory {self.datasets_dir}: {e}")
+            return
+
+        filename = os.path.join(self.datasets_dir, "remaining_unsolved.jsonl")
+
+        try:
+            with open(filename, "w", encoding="utf-8") as f:
+                for item in self.active_train_queue:
+                    # Add metadata about the current state
+                    item_with_metadata = dict(item)
+                    item_id = f"{item['func_name']}_{hash(str(item)) % 100000}"
+                    item_with_metadata["_curriculum_metadata"] = {
+                        "item_id": item_id,
+                        "attempt_count": self.item_attempt_counts.get(item_id, 0),
+                        "queue_position": self.active_train_queue.index(item),
+                        "total_active": len(self.active_train_queue),
+                        "total_solved": len(self.solved_items),
+                        "iteration_dumped": self.iter,
+                    }
+                    json.dump(item_with_metadata, f, ensure_ascii=False)
+                    f.write("\n")
+
+            print(
+                f"📁 Dumped {len(self.active_train_queue)} unsolved items to {filename} (iteration {self.iter})"
+            )
+            print(
+                f"   Queue status: {len(self.active_train_queue)} active, {len(self.solved_items)} solved"
+            )
+
+        except Exception as e:
+            print(f"Error dumping active queue to {filename}: {e}")
+
+    async def _load_from_unsolved_dataset(self):
+        """
+        Load the active queue from a previously saved remaining_unsolved.jsonl file.
+        This allows resuming training from where it left off.
+
+        Note: When resuming, the training items come from the saved file, but the test set
+        still comes from the current dataset_name configuration.
+        """
+        try:
+            with open(
+                self.config.resume_from_unsolved_dataset, "r", encoding="utf-8"
+            ) as f:
+                loaded_items = []
+                loaded_attempt_counts = {}
+                original_dataset_info = None
+
+                for line in f:
+                    item_data = json.loads(line.strip())
+
+                    # Extract curriculum metadata if present
+                    metadata = item_data.pop("_curriculum_metadata", {})
+                    item_id = metadata.get("item_id")
+                    attempt_count = metadata.get("attempt_count", 0)
+
+                    # Store info about the original dataset for validation
+                    if original_dataset_info is None and "iteration_dumped" in metadata:
+                        original_dataset_info = {
+                            "total_active_at_dump": metadata.get("total_active"),
+                            "total_solved_at_dump": metadata.get("total_solved"),
+                            "iteration_dumped": metadata.get("iteration_dumped"),
+                        }
+
+                    if item_id and attempt_count > 0:
+                        loaded_attempt_counts[item_id] = attempt_count
+
+                    # Validate that the item has the expected structure
+                    required_fields = ["prompt", "func_name", "args"]
+                    if not all(field in item_data for field in required_fields):
+                        print(
+                            f"Warning: Skipping malformed item missing required fields: {list(item_data.keys())}"
+                        )
+                        continue
+
+                    loaded_items.append(item_data)
+
+            if not loaded_items:
+                raise ValueError("No valid items found in resume file")
+
+            self.active_train_queue = loaded_items
+            self.solved_items = []  # Start with no solved items when resuming
+            self.item_attempt_counts = loaded_attempt_counts
+
+            print(
+                f"📂 Loaded {len(loaded_items)} unsolved items from {self.config.resume_from_unsolved_dataset}"
+            )
+            if loaded_attempt_counts:
+                avg_attempts = sum(loaded_attempt_counts.values()) / len(
+                    loaded_attempt_counts
+                )
+                print(
+                    f"   Restored attempt counts for {len(loaded_attempt_counts)} items (avg: {avg_attempts:.1f} attempts)"  # noqa
+                )
+
+            if original_dataset_info:
+                print(
+                    f"   Original dump info: {original_dataset_info['total_solved_at_dump']} solved, "
+                    f"{original_dataset_info['total_active_at_dump']} active at iteration {original_dataset_info['iteration_dumped']}"  # noqa
+                )
+
+                # Validate compatibility with current dataset
+                if hasattr(self, "train") and len(self.train) > 0:
+                    original_total = original_dataset_info.get(
+                        "total_active_at_dump", 0
+                    ) + original_dataset_info.get("total_solved_at_dump", 0)
+                    current_total = len(self.train)
+
+                    if original_total != current_total:
+                        print("⚠️ Warning: Dataset size mismatch!")
+                        print(
+                            f"   Original dataset had {original_total} items, current dataset has {current_total} items"  # noqa
+                        )
+                        print(
+                            "   This might indicate different dataset versions or configurations"
+                        )
+
+        except FileNotFoundError:
+            print(
+                f"❌ Resume file not found: {self.config.resume_from_unsolved_dataset}"
+            )
+            print("Falling back to full dataset initialization...")
+            self.active_train_queue = list(self.train)
+            self.solved_items = []
+            self.item_attempt_counts = {}
+        except Exception as e:
+            print(f"❌ Error loading from unsolved dataset: {e}")
+            print("Falling back to full dataset initialization...")
+            self.active_train_queue = list(self.train)
+            self.solved_items = []
+            self.item_attempt_counts = {}
+
+    def _handle_item_result(self, item: Item, group_average_score: float):
+        """
+        Handle the result of an item based on its group average score.
+        If solved (high score), remove it from circulation.
+        If not solved (low score), add it back to the end of the queue.
+        """
+        _, answer_info = item
+        raw_item = answer_info.get("raw_item")
+        item_id = answer_info.get("item_id")
+        attempt_count = answer_info.get("attempt_count", 1)
+
+        if not raw_item or not item_id:
+            return  # Skip if we don't have the necessary info
+
+        # Define "solved" as either too easy for training OR very high performance
+        is_solved = (
+            group_average_score > self.config.max_group_average_for_training
+            or group_average_score >= 0.9
+        )  # Very high performance threshold
+
+        if is_solved:
+            # Item is solved - move to solved items (remove from circulation)
+            self.solved_items.append(raw_item)
+            status = "SOLVED - removed from circulation"
+            if group_average_score > self.config.max_group_average_for_training:
+                status += " (too easy)"
+            else:
+                status += " (mastered)"
+        else:
+            # Item not solved - add back to the end of the active queue
+            self.active_train_queue.append(raw_item)
+            status = "NOT SOLVED - added back to queue"
+
+        # Log the decision periodically or for items with many attempts
+        if attempt_count % 5 == 1 or attempt_count <= 3 or is_solved:
+            func_name = answer_info.get("func_name", "unknown")
+            print(
+                f"📚 Item {item_id} (attempt #{attempt_count}): {func_name} | Score: {group_average_score:.4f} | {status}"  # noqa
+            )
+            print(
+                f"   Queue status: {len(self.active_train_queue)} active, {len(self.solved_items)} solved"
+            )
+
     async def add_rollouts_for_wandb(
         self,
         scored_data: ScoredDataGroup,  # Assuming single ScoredDataGroup here
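Together, the dump/load pair above gives a simple resume workflow. A hedged sketch of restarting from the periodic dump (the path follows the defaults in this diff; any other required config fields are omitted):

    # Illustrative resume flow: inspect the periodic dump, then point the
    # config at it.
    import json

    resume_path = "datasets/remaining_unsolved.jsonl"  # written next to the env file
    with open(resume_path, "r", encoding="utf-8") as f:
        items = [json.loads(line) for line in f]

    meta = items[0]["_curriculum_metadata"]
    print(f"{meta['total_active']} active / {meta['total_solved']} solved "
          f"at iteration {meta['iteration_dumped']}")

    config = IFConfig(resume_from_unsolved_dataset=resume_path)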
@@ -736,6 +1263,78 @@ class InstructionFollowingEnv(BaseEnv):
         if len(self.rollouts_for_wandb) > self.config.num_rollouts_to_keep:
             self.rollouts_for_wandb.pop(0)

+    async def _save_rollouts_to_jsonl(self):
+        """Saves the buffered rollouts to a JSONL file in the datadumps directory."""
+        if not self.rollouts_to_save_buffer:
+            print("Warning: _save_rollouts_to_jsonl called but buffer is empty!")
+            return
+
+        buffer_size = len(self.rollouts_to_save_buffer)
+        print(f"Starting save of {buffer_size} rollout groups to JSONL file...")
+
+        try:
+            if not os.path.exists(self.datadumps_dir):
+                os.makedirs(self.datadumps_dir)
+                print(f"Created directory: {self.datadumps_dir}")
+        except Exception as e:
+            print(f"Error creating directory {self.datadumps_dir}: {e}")
+            return
+
+        filename = os.path.join(
+            self.datadumps_dir,
+            f"instruction_following_rollouts_{self.run_uuid}_{self.save_file_batch_num:04d}.jsonl",
+        )
+
+        try:
+            with open(filename, "w", encoding="utf-8") as f:
+                for rollout_dict in self.rollouts_to_save_buffer:
+                    json.dump(rollout_dict, f, ensure_ascii=False)
+                    f.write("\n")
+
+            print(f"Successfully saved {buffer_size} rollout groups to {filename}")
+            self.save_file_batch_num += 1
+            self.rollouts_to_save_buffer.clear()
+
+        except Exception as e:
+            print(f"Error saving rollouts to {filename}: {e}")
+
+    async def _save_failed_rollouts_to_jsonl(self):
+        """Saves the buffered failed rollouts to a JSONL file for debugging."""
+        if not self.failed_rollouts_to_save_buffer:
+            print("Warning: _save_failed_rollouts_to_jsonl called but buffer is empty!")
+            return
+
+        buffer_size = len(self.failed_rollouts_to_save_buffer)
+        print(f"Starting save of {buffer_size} failed rollout groups to JSONL file...")
+
+        try:
+            if not os.path.exists(self.datadumps_dir):
+                os.makedirs(self.datadumps_dir)
+                print(f"Created directory: {self.datadumps_dir}")
+        except Exception as e:
+            print(f"Error creating directory {self.datadumps_dir}: {e}")
+            return
+
+        filename = os.path.join(
+            self.datadumps_dir,
+            f"instruction_following_failed_rollouts_{self.run_uuid}_{self.failed_save_file_batch_num:04d}.jsonl",
+        )
+
+        try:
+            with open(filename, "w", encoding="utf-8") as f:
+                for rollout_dict in self.failed_rollouts_to_save_buffer:
+                    json.dump(rollout_dict, f, ensure_ascii=False)
+                    f.write("\n")
+
+            print(
+                f"Successfully saved {buffer_size} failed rollout groups to {filename}"
+            )
+            self.failed_save_file_batch_num += 1
+            self.failed_rollouts_to_save_buffer.clear()
+
+        except Exception as e:
+            print(f"Error saving failed rollouts to {filename}: {e}")
+

 # ----- IFEval Verifier Functions and Map -----
 # adapted from https://github.com/allenai/open-instruct/blob/main/scripts/eval_constraints/if_functions.py
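Finally, a hedged snippet for inspecting the dump files these methods produce, for example to split groups by whether they were used for training:

    # Illustrative: load every dumped group and split by training usage.
    import glob
    import json

    groups = []
    for path in sorted(glob.glob("data_dumps/instruction_following_rollouts_*.jsonl")):
        with open(path, "r", encoding="utf-8") as f:
            groups.extend(json.loads(line) for line in f)

    trained = [g for g in groups if not g["skipped_for_training"]]
    skipped = [g for g in groups if g["skipped_for_training"]]
    print(f"{len(trained)} groups used for training, {len(skipped)} skipped as too easy")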