Add cycling curriculum, difficulty threshold, update datadumps

teknium1 2025-06-14 07:44:47 -07:00
parent ab9e90a173
commit ad1bdf7f80
2 changed files with 737 additions and 60 deletions
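At a high level, the cycling curriculum keeps every training item in a circulating queue: items whose rollout groups score above the difficulty threshold are retired as solved, and unsolved items go back to the end of the queue for another attempt. A minimal sketch of that loop, with invented names (run_curriculum, rollout_fn); the real logic lives in get_next_item, collect_trajectories and _handle_item_result below:

from collections import deque

def run_curriculum(items, rollout_fn, max_group_average=0.75, mastered=0.9):
    # items: dataset rows (dicts with "prompt", "func_name", "args")
    # rollout_fn: returns one score per rollout in the group for a given item
    active = deque(items)   # items still in circulation
    solved = []             # items retired from circulation
    attempts = {}           # item_id -> number of attempts so far
    while active:           # the real environment keeps cycling for as long as training runs
        item = active.popleft()
        item_id = f"{item['func_name']}_{hash(str(item)) % 100000}"
        attempts[item_id] = attempts.get(item_id, 0) + 1
        scores = rollout_fn(item)
        avg = sum(scores) / len(scores)
        if avg > max_group_average or avg >= mastered:
            solved.append(item)      # too easy or mastered: drop from circulation
        else:
            active.append(item)      # not solved yet: cycle to the back of the queue
    return solved, attempts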


@ -1,14 +1,16 @@
import json
import os
import random
import re
import uuid
from typing import Dict, List, Optional, Tuple
import wandb
from datasets import Dataset, load_dataset
from langdetect import LangDetectException, detect
from pydantic import Field
from tqdm.asyncio import tqdm_asyncio
import wandb
from atroposlib.envs.base import (
APIServerConfig,
BaseEnv,
@ -36,6 +38,31 @@ class IFConfig(BaseEnvConfig):
test_set_ratio: float = Field(
0.05, description="The ratio of the selected dataset for testing"
)
dump_rollouts: bool = Field(
False, description="Whether to dump successful rollouts to JSONL files"
)
dump_failed_rollouts: bool = Field(
False,
description="Whether to dump failed rollouts (all 0 scores) to JSONL files for debugging",
)
rollout_save_score_threshold: float = Field(
0.7, description="Minimum score threshold for saving rollouts to data dumps"
)
max_group_average_for_training: float = Field(
0.75,
description="Maximum group average score to use for training (skip groups that are too easy)",
)
dataset_shuffle_seed: int = Field(
42, description="Seed for shuffling the dataset during setup"
)
resume_from_unsolved_dataset: Optional[str] = Field(
None,
description="Path to a remaining_unsolved.jsonl file to resume training from specific unsolved items",
)
suppress_base_env_logs: bool = Field(
default=True,
description="Suppress verbose base environment logs (like status dict updates).",
)
class InstructionFollowingEnv(BaseEnv):
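To make the new fields concrete, here is a hedged sketch of the intended resume workflow (the variable name and values below are illustrative, not part of this diff): every 100 iterations the environment writes the still-unsolved items to datasets/remaining_unsolved.jsonl, and a later run can point at that file for its training items while evaluation still comes from dataset_name.

# Illustrative config overrides for a resume run (values are examples only).
resume_overrides = dict(
    dataset_name="allenai/RLVR-IFeval",          # still used for the test/eval split
    resume_from_unsolved_dataset="datasets/remaining_unsolved.jsonl",  # training items
    max_group_average_for_training=0.75,         # skip groups that are too easy
    dump_rollouts=True,                          # keep writing data dumps on the new run
    rollout_save_score_threshold=0.7,
)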
@ -53,6 +80,42 @@ class InstructionFollowingEnv(BaseEnv):
self.eval_metrics = list()
self.rollouts_for_wandb = []
# Data dumping infrastructure
self.rollouts_to_save_buffer = []
self.failed_rollouts_to_save_buffer = []
self.run_uuid = str(uuid.uuid4())[:8]
self.save_file_batch_num = 0
self.failed_save_file_batch_num = 0
# Adaptive curriculum: cycling queue for unsolved items
self.active_train_queue = [] # Items currently in circulation
self.solved_items = [] # Items that have been solved (removed from circulation)
# Track how many times each item has been attempted
self.item_attempt_counts = {}
# Create data dumps directory
self.datadumps_dir = os.path.join(
os.path.dirname(os.path.abspath(__file__)), "data_dumps"
)
# Create datasets directory for curriculum state dumps
self.datasets_dir = os.path.join(
os.path.dirname(os.path.abspath(__file__)), "datasets"
)
# Validate configuration for potential conflicts
self._validate_config()
# Configure logging suppression
if self.config.suppress_base_env_logs:
import logging
# Suppress specific loggers that are too verbose
logging.getLogger("atroposlib.envs.base").setLevel(logging.WARNING)
logging.getLogger("httpx").setLevel(logging.WARNING)
logging.getLogger("httpcore").setLevel(logging.WARNING)
@classmethod
def config_init(
self,
@ -74,6 +137,13 @@ class InstructionFollowingEnv(BaseEnv):
dataset_name="allenai/RLVR-IFeval", # Default dataset
dataset_config_name=None, # RLVR-IFeval doesn't have a specific config name, uses 'default'
test_set_ratio=0.05, # Fraction of the selected dataset held out for testing
dump_rollouts=False, # Enable data dumping if needed
dump_failed_rollouts=False, # Enable failed rollout dumping for debugging
rollout_save_score_threshold=0.7, # Save rollouts with score >= 0.7
max_group_average_for_training=0.75, # Skip groups that are too easy for training
dataset_shuffle_seed=42, # Seed for dataset shuffling
resume_from_unsolved_dataset=None, # Path to resume from unsolved items
suppress_base_env_logs=True, # Suppress verbose base environment logs
)
# Server configurations can be similar to SingleToolCallingEnv or adjusted
server_configs = [
@ -111,6 +181,31 @@ class InstructionFollowingEnv(BaseEnv):
except ZeroDivisionError:
pass # Buffer might be empty
# Add adaptive curriculum metrics
total_items = len(self.active_train_queue) + len(self.solved_items)
if total_items > 0:
wandb_metrics["curriculum/active_items"] = len(self.active_train_queue)
wandb_metrics["curriculum/solved_items"] = len(self.solved_items)
wandb_metrics["curriculum/percent_solved"] = (
len(self.solved_items) / total_items
)
wandb_metrics["curriculum/total_items"] = total_items
# Average attempt count for items still in circulation
if self.item_attempt_counts:
active_ids = {
f"{item['func_name']}_{hash(str(item)) % 100000}"
for item in self.active_train_queue
}
active_attempts = [
count
for item_id, count in self.item_attempt_counts.items()
if item_id in active_ids
]
if active_attempts:
wandb_metrics["curriculum/avg_attempts_active"] = sum(
active_attempts
) / len(active_attempts)
self.percent_correct_buffer = list()
for item in self.eval_metrics:
wandb_metrics[item[0]] = item[1]
@ -278,7 +373,7 @@ class InstructionFollowingEnv(BaseEnv):
f"due to previous errors."
)
full_dataset = full_dataset.shuffle(seed=42)
full_dataset = full_dataset.shuffle(seed=self.config.dataset_shuffle_seed)
actual_test_size = self.config.test_set_ratio # Read from config
num_items = len(full_dataset)
@ -328,9 +423,46 @@ class InstructionFollowingEnv(BaseEnv):
self.test = self.train
self.iter = 0
# Initialize the adaptive curriculum queue
if self.config.resume_from_unsolved_dataset:
print(
f"🔄 Resume mode: Loading unsolved items from {self.config.resume_from_unsolved_dataset}"
)
print(
f" Note: This will override the dataset_name '{self.config.dataset_name}' for training items"
)
print(f" Test set will still use items from '{self.config.dataset_name}'")
await self._load_from_unsolved_dataset()
else:
# Initialize with all training items
self.active_train_queue = list(self.train)
self.solved_items = []
self.item_attempt_counts = {}
print(
f"Dataset setup complete. Train size: {len(self.train)}, Test size: {len(self.test)}"
)
print(
f"Adaptive curriculum initialized with {len(self.active_train_queue)} items in active queue"
)
def _validate_config(self):
"""Validate configuration for potential conflicts and warn user."""
if self.config.resume_from_unsolved_dataset and self.config.dataset_name:
print("⚠️ Configuration Notice:")
print(
f" Both 'dataset_name' ({self.config.dataset_name}) and 'resume_from_unsolved_dataset' are set"
)
print(" Behavior:")
print(
" - Training items: Will come from the resume file (overrides dataset_name)"
)
print(" - Test/eval items: Will come from dataset_name")
print(
" - This is useful for resuming training while keeping consistent evaluation"
)
print()
async def _get_score_from_verifier(
self, model_response_text: str, func_name: str, args: Dict
@ -448,10 +580,6 @@ class InstructionFollowingEnv(BaseEnv):
func_name = test_item["func_name"]
args_for_verifier = test_item["args"]
print(
f"DEBUG: Entering rollout_and_score_eval. Prompt: {instruction_prompt_text[:200]}..."
) # DEBUG
messages = [{"role": "system", "content": system_prompt}]
messages.append({"role": "user", "content": instruction_prompt_text})
@ -459,9 +587,6 @@ class InstructionFollowingEnv(BaseEnv):
messages, add_generation_prompt=True, tokenize=False
)
print(
f"DEBUG: Calling self.server.completion in rollout_and_score_eval. Prompt: {prompt_str[:200]}..."
) # DEBUG
completion = await self.server.completion(
prompt=prompt_str,
n=1,
@ -469,7 +594,6 @@ class InstructionFollowingEnv(BaseEnv):
temperature=0.2, # Temperature for eval, can be 0 for deterministic
split="eval",
)
print("DEBUG: Received completion in rollout_and_score_eval.") # DEBUG
model_response_text = completion.choices[0].text
score_value = await self._get_score_from_verifier(
@ -487,12 +611,12 @@ class InstructionFollowingEnv(BaseEnv):
self.eval_metrics.append(("eval/percent_correct", 0.0))
return
print(f"DEBUG: Starting evaluation. Test set size: {len(self.test)}") # DEBUG
print(f"Starting evaluation on {len(self.test)} items...")
eval_tasks = []
for test_item_dict in self.test: # self.test contains dicts after setup
eval_tasks.append(self.rollout_and_score_eval(test_item_dict))
scores = await tqdm_asyncio.gather(*eval_tasks)
scores = await tqdm_asyncio.gather(*eval_tasks, desc="Evaluating")
if not scores: # If gather returns empty list
percent_correct = 0.0
@ -500,14 +624,13 @@ class InstructionFollowingEnv(BaseEnv):
percent_correct = sum(scores) / len(scores)
self.eval_metrics.append(("eval/percent_correct", percent_correct))
print(f"Evaluation percent correct: {percent_correct}")
print(f"Evaluation finished. Percent correct: {percent_correct:.4f}")
async def collect_trajectories(
self, item: Item
) -> Tuple[Optional[ScoredDataGroup], List]:
# item = (prompt_messages_tuple, answer_info_dict)
# answer_info_dict = {"func_name": ..., "args": ...}
print(f"DEBUG: Entering collect_trajectories. Item: {str(item)}") # DEBUG
prompt_messages_list = [dict(msg_fset) for msg_fset in item[0]]
answer_info = item[1]
@ -515,9 +638,6 @@ class InstructionFollowingEnv(BaseEnv):
prompt_messages_list, add_generation_prompt=True, tokenize=False
)
print(
f"DEBUG: Calling self.server.completion in collect_trajectories. Prompt: {prompt_str[:200]}..."
) # DEBUG
try:
completions = await self.server.completion(
prompt=prompt_str,
@ -525,14 +645,8 @@ class InstructionFollowingEnv(BaseEnv):
max_tokens=self.config.max_token_length,
temperature=0.8, # Temperature for diverse responses during training rollouts
)
print(
f"DEBUG: Received {len(completions.choices)} completions in collect_trajectories."
) # DEBUG
except Exception as e:
print(
f"ERROR: Exception during self.server.completion in collect_trajectories: {e}"
) # DEBUG
# Depending on the desired behavior, you might want to return None or raise the exception
print(f"ERROR: Exception during completion generation: {e}")
return None, []
to_score_list = []
@ -546,23 +660,73 @@ class InstructionFollowingEnv(BaseEnv):
if not to_score_list:
return None, []
print(
f"DEBUG: Scoring {len(to_score_list)} trajectories in collect_trajectories."
) # DEBUG
scored_data = await self.score(to_score_list)
# Handle adaptive curriculum: decide whether to keep item in circulation
if scored_data and scored_data.get("scores"):
group_average_score = sum(scored_data["scores"]) / len(
scored_data["scores"]
)
self._handle_item_result(item, group_average_score)
elif scored_data is None:
# If scored_data is None, it might be because the group was skipped for being too easy
# We need to calculate the scores ourselves to handle the item properly
temp_scores = []
for trajectory_messages, answer_info in to_score_list:
model_response_text = trajectory_messages[-1]["content"]
func_name = answer_info["func_name"]
args_for_verifier = answer_info["args"]
# Get score (1.0 for correct, 0.0 for incorrect from verifier)
score_value = await self._get_score_from_verifier(
model_response_text, func_name, args_for_verifier
)
reward = 1.0 if score_value == 1.0 else 0
temp_scores.append(reward)
if temp_scores:
group_average_score = sum(temp_scores) / len(temp_scores)
self._handle_item_result(item, group_average_score)
to_backlog = [] # Backlog not currently used but part of signature
print(
f"DEBUG: Exiting collect_trajectories. Scored data: {bool(scored_data)}"
) # DEBUG
return scored_data, to_backlog
def save_checkpoint(self, step, data=None):
if data is None:
data = {}
data["iter"] = self.iter
data["save_file_batch_num"] = self.save_file_batch_num
data["failed_save_file_batch_num"] = self.failed_save_file_batch_num
# Save adaptive curriculum state
data["active_train_queue"] = self.active_train_queue
data["solved_items"] = self.solved_items
data["item_attempt_counts"] = self.item_attempt_counts
super().save_checkpoint(step, data)
async def close(self):
"""Save any remaining rollouts and curriculum state before closing."""
if self.config.dump_rollouts and self.rollouts_to_save_buffer:
print(
f"Saving {len(self.rollouts_to_save_buffer)} remaining rollouts before closing..."
)
await self._save_rollouts_to_jsonl()
if self.config.dump_failed_rollouts and self.failed_rollouts_to_save_buffer:
print(
f"Saving {len(self.failed_rollouts_to_save_buffer)} remaining failed rollouts before closing..."
)
await self._save_failed_rollouts_to_jsonl()
# Save final curriculum state
if self.active_train_queue:
print(
f"Saving final curriculum state with {len(self.active_train_queue)} unsolved items..."
)
await self._dump_active_queue_dataset()
await super().close()
async def score(
self, rollout_group_data: List[Tuple[tuple, Dict]]
) -> Optional[ScoredDataGroup]:
@ -583,6 +747,10 @@ class InstructionFollowingEnv(BaseEnv):
random.shuffle(rollout_group_data) # Shuffle to avoid bias
# Data dumping: collect rollouts for saving (group format)
rollouts_for_this_group = []
failed_rollouts_for_this_group = []
for trajectory_item in rollout_group_data:
full_trajectory_messages = trajectory_item[0]
answer_info = trajectory_item[1] # {"func_name": ..., "args": ...}
@ -599,6 +767,24 @@ class InstructionFollowingEnv(BaseEnv):
# Map to reward: 1.0 for correct, 0 for incorrect
reward = 1.0 if score_value == 1.0 else 0
# Prepare structured conversation for data dumping
conversation = [dict(msg) for msg in full_trajectory_messages]
# Create rollout dict for this specific rollout in the group
rollout_dict = {
"conversation": conversation,
"score": reward,
}
# Collect rollouts for this group based on score and config
if (
self.config.dump_rollouts
and reward >= self.config.rollout_save_score_threshold
):
rollouts_for_this_group.append(rollout_dict)
elif self.config.dump_failed_rollouts and reward == 0:
failed_rollouts_for_this_group.append(rollout_dict)
# Tokenize the conversation for PPO training
# Ensure full_trajectory_messages is a list of dicts
list_of_dicts_trajectory = [dict(msg) for msg in full_trajectory_messages]
@ -623,6 +809,128 @@ class InstructionFollowingEnv(BaseEnv):
if not scores_container["tokens"]: # No valid items collected
return None
# Calculate group average score for difficulty filtering and logging
current_scores = scores_container.get("scores", [])
if current_scores:
average_score = sum(current_scores) / len(current_scores)
# Get task info from the first rollout's answer_info
answer_info = rollout_group_data[0][1] if rollout_group_data else {}
func_name = answer_info.get("func_name", "unknown_task")
# Check if group is too easy for training (but still allow data dumping)
if average_score > self.config.max_group_average_for_training:
print(
f"Task: {func_name} | Group average score: {average_score:.4f} (SKIPPED - too easy for training, threshold: {self.config.max_group_average_for_training})" # noqa
)
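# Worked example (illustrative numbers): with max_group_average_for_training = 0.75
# and a group of 8 rollouts, scores [1, 1, 1, 1, 1, 1, 1, 0] average 0.875, so the
# group is skipped for training (though it may still be written to the data dumps
# below), while scores [1, 1, 0, 1, 0, 0, 1, 0] average 0.5 and the group is kept.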
# Still handle data dumping for groups that are too easy for training
# but might be useful for analysis
if (
rollouts_for_this_group
and average_score
<= self.config.max_group_average_for_training + 0.1
): # Small buffer for data collection
# Extract item info for the group - get from first rollout's answer_info
answer_info = rollout_group_data[0][1]
item_id = f"allenai_RLVR-IFeval_train_item_{answer_info.get('func_name', 'unknown')}_{hash(str(answer_info)) % 100000}" # noqa
group_data_to_save = {
"item_id": item_id,
"rollouts": rollouts_for_this_group,
"constraint_details": answer_info, # Store group-level metadata
"group_average_score": average_score, # Add group average for analysis
"skipped_for_training": True, # Mark as skipped for training
}
self.rollouts_to_save_buffer.append(group_data_to_save)
if failed_rollouts_for_this_group:
# Extract item info for the failed group
answer_info = rollout_group_data[0][1]
item_id = f"allenai_RLVR-IFeval_train_item_{answer_info.get('func_name', 'unknown')}_{hash(str(answer_info)) % 100000}" # noqa
failed_group_data_to_save = {
"item_id": item_id,
"rollouts": failed_rollouts_for_this_group,
"constraint_details": answer_info, # Store group-level metadata
"group_average_score": average_score, # Add group average for analysis
"skipped_for_training": True, # Mark as skipped for training
}
self.failed_rollouts_to_save_buffer.append(
failed_group_data_to_save
)
# Save rollouts if buffer is getting large (batch processing)
if (
self.config.dump_rollouts
and len(self.rollouts_to_save_buffer) >= 100
):
await self._save_rollouts_to_jsonl()
if (
self.config.dump_failed_rollouts
and len(self.failed_rollouts_to_save_buffer) >= 50
):
await self._save_failed_rollouts_to_jsonl()
return None # Skip this group for training
log_message = (
f"Task: {func_name} | Group average score: {average_score:.4f}"
)
if all(s >= 0.5 for s in current_scores):
print(f"{log_message} (All correct in this group!)")
elif all(s == 0.0 for s in current_scores):
print(f"{log_message} (All failed - format/constraint violations!)")
elif all(s < 0.5 for s in current_scores):
print(f"{log_message} (All incorrect but some partial credit!)")
else:
print(log_message)
# Create group data structure and add to buffers for data dumping (for training groups)
if rollouts_for_this_group:
# Extract item info for the group - get from first rollout's answer_info
answer_info = rollout_group_data[0][1]
item_id = f"allenai_RLVR-IFeval_train_item_{answer_info.get('func_name', 'unknown')}_{hash(str(answer_info)) % 100000}" # noqa
group_data_to_save = {
"item_id": item_id,
"rollouts": rollouts_for_this_group,
"constraint_details": answer_info, # Store group-level metadata
"group_average_score": (
current_scores[0]
if len(current_scores) == 1
else sum(current_scores) / len(current_scores)
), # Add group average for analysis
"skipped_for_training": False, # Mark as used for training
}
self.rollouts_to_save_buffer.append(group_data_to_save)
if failed_rollouts_for_this_group:
# Extract item info for the failed group
answer_info = rollout_group_data[0][1]
item_id = f"allenai_RLVR-IFeval_train_item_{answer_info.get('func_name', 'unknown')}_{hash(str(answer_info)) % 100000}" # noqa
failed_group_data_to_save = {
"item_id": item_id,
"rollouts": failed_rollouts_for_this_group,
"constraint_details": answer_info, # Store group-level metadata
"group_average_score": (
current_scores[0]
if len(current_scores) == 1
else sum(current_scores) / len(current_scores)
), # Add group average for analysis
"skipped_for_training": False, # Mark as used for training
}
self.failed_rollouts_to_save_buffer.append(failed_group_data_to_save)
# Save rollouts if buffer is getting large (batch processing)
if self.config.dump_rollouts and len(self.rollouts_to_save_buffer) >= 100:
await self._save_rollouts_to_jsonl()
if (
self.config.dump_failed_rollouts
and len(self.failed_rollouts_to_save_buffer) >= 50
):
await self._save_failed_rollouts_to_jsonl()
# Record success rate for logging (based on positive rewards)
for rwd in scores_container["scores"]:
self.percent_correct_buffer.append(
@ -665,26 +973,48 @@ class InstructionFollowingEnv(BaseEnv):
return scores_container
async def get_next_item(self) -> Item:
# Fetches the next preprocessed item from the training set
if not self.train or len(self.train) == 0:
# This case should be handled by setup, but as a safeguard:
print("Error: Training data is empty in get_next_item.")
# Return a dummy item to prevent crashes, though this indicates a setup issue
dummy_prompt_messages = (
frozenset({"role": "system", "content": system_prompt}.items()),
frozenset(
{"role": "user", "content": "Dummy instruction: say hello."}.items()
),
)
dummy_answer_info = {
"func_name": "verify_keywords",
"args": {"keyword_list": ["hello"]},
}
return (dummy_prompt_messages, dummy_answer_info)
# Fetches the next item from the adaptive curriculum queue
if not self.active_train_queue:
# If active queue is empty, check if we have any items left
if not self.solved_items:
print("Error: No training data available in get_next_item.")
# Return a dummy item to prevent crashes
dummy_prompt_messages = (
frozenset({"role": "system", "content": system_prompt}.items()),
frozenset(
{
"role": "user",
"content": "Dummy instruction: say hello.",
}.items()
),
)
dummy_answer_info = {
"func_name": "verify_keywords",
"args": {"keyword_list": ["hello"]},
}
return (dummy_prompt_messages, dummy_answer_info)
else:
# All items have been solved! Reset the queue with solved items for continued training
print(
f"🎉 All {len(self.solved_items)} items have been solved! Resetting queue for continued training..."
)
self.active_train_queue = list(self.solved_items)
self.solved_items = []
# Reset attempt counts for the new cycle
self.item_attempt_counts = {}
raw_item = self.train[self.iter % len(self.train)] # raw_item is a dict
# Get the next item from the front of the active queue
raw_item = self.active_train_queue.pop(0)
self.iter += 1
# Create a unique identifier for this item for tracking
item_id = f"{raw_item['func_name']}_{hash(str(raw_item)) % 100000}"
# Track attempt count
if item_id not in self.item_attempt_counts:
self.item_attempt_counts[item_id] = 0
self.item_attempt_counts[item_id] += 1
instruction_prompt_text = raw_item["prompt"]
# Construct messages for the LLM (prompt tuple part of Item)
@ -698,6 +1028,10 @@ class InstructionFollowingEnv(BaseEnv):
answer_info = {
"func_name": raw_item["func_name"],
"args": raw_item["args"],
# Add item tracking info
"item_id": item_id,
"raw_item": raw_item, # Store the full item for queue management
"attempt_count": self.item_attempt_counts[item_id],
# Optionally include other info for logging/debugging if needed from raw_item
"original_constraints_for_logging": raw_item.get(
"original_constraints", ""
@ -707,8 +1041,201 @@ class InstructionFollowingEnv(BaseEnv):
),
}
# Dump active queue every 100 iterations for resumability
if self.iter % 100 == 0 and self.iter > 0:
await self._dump_active_queue_dataset()
return (prompt_messages_tuple, answer_info)
async def _dump_active_queue_dataset(self):
"""
Dumps the current active queue to a JSONL file so training can be resumed
from the unsolved items if the environment is shut down.
"""
if not self.active_train_queue:
print("No active items to dump - all items have been solved!")
return
try:
if not os.path.exists(self.datasets_dir):
os.makedirs(self.datasets_dir)
print(f"Created datasets directory: {self.datasets_dir}")
except Exception as e:
print(f"Error creating datasets directory {self.datasets_dir}: {e}")
return
filename = os.path.join(self.datasets_dir, "remaining_unsolved.jsonl")
try:
with open(filename, "w", encoding="utf-8") as f:
for queue_position, item in enumerate(self.active_train_queue):
# Add metadata about the current state
item_with_metadata = dict(item)
item_id = f"{item['func_name']}_{hash(str(item)) % 100000}"
item_with_metadata["_curriculum_metadata"] = {
"item_id": item_id,
"attempt_count": self.item_attempt_counts.get(item_id, 0),
"queue_position": self.active_train_queue.index(item),
"total_active": len(self.active_train_queue),
"total_solved": len(self.solved_items),
"iteration_dumped": self.iter,
}
json.dump(item_with_metadata, f, ensure_ascii=False)
f.write("\n")
print(
f"📁 Dumped {len(self.active_train_queue)} unsolved items to {filename} (iteration {self.iter})"
)
print(
f" Queue status: {len(self.active_train_queue)} active, {len(self.solved_items)} solved"
)
except Exception as e:
print(f"Error dumping active queue to {filename}: {e}")
async def _load_from_unsolved_dataset(self):
"""
Load the active queue from a previously saved remaining_unsolved.jsonl file.
This allows resuming training from where it left off.
Note: When resuming, the training items come from the saved file, but the test set
still comes from the current dataset_name configuration.
"""
try:
with open(
self.config.resume_from_unsolved_dataset, "r", encoding="utf-8"
) as f:
loaded_items = []
loaded_attempt_counts = {}
original_dataset_info = None
for line in f:
item_data = json.loads(line.strip())
# Extract curriculum metadata if present
metadata = item_data.pop("_curriculum_metadata", {})
item_id = metadata.get("item_id")
attempt_count = metadata.get("attempt_count", 0)
# Store info about the original dataset for validation
if original_dataset_info is None and "iteration_dumped" in metadata:
original_dataset_info = {
"total_active_at_dump": metadata.get("total_active"),
"total_solved_at_dump": metadata.get("total_solved"),
"iteration_dumped": metadata.get("iteration_dumped"),
}
if item_id and attempt_count > 0:
loaded_attempt_counts[item_id] = attempt_count
# Validate that the item has the expected structure
required_fields = ["prompt", "func_name", "args"]
if not all(field in item_data for field in required_fields):
print(
f"Warning: Skipping malformed item missing required fields: {list(item_data.keys())}"
)
continue
loaded_items.append(item_data)
if not loaded_items:
raise ValueError("No valid items found in resume file")
self.active_train_queue = loaded_items
self.solved_items = [] # Start with no solved items when resuming
self.item_attempt_counts = loaded_attempt_counts
print(
f"📂 Loaded {len(loaded_items)} unsolved items from {self.config.resume_from_unsolved_dataset}"
)
if loaded_attempt_counts:
avg_attempts = sum(loaded_attempt_counts.values()) / len(
loaded_attempt_counts
)
print(
f" Restored attempt counts for {len(loaded_attempt_counts)} items (avg: {avg_attempts:.1f} attempts)" # noqa
)
if original_dataset_info:
print(
f" Original dump info: {original_dataset_info['total_solved_at_dump']} solved, "
f"{original_dataset_info['total_active_at_dump']} active at iteration {original_dataset_info['iteration_dumped']}" # noqa
)
# Validate compatibility with current dataset
if hasattr(self, "train") and len(self.train) > 0:
original_total = original_dataset_info.get(
"total_active_at_dump", 0
) + original_dataset_info.get("total_solved_at_dump", 0)
current_total = len(self.train)
if original_total != current_total:
print("⚠️ Warning: Dataset size mismatch!")
print(
f" Original dataset had {original_total} items, current dataset has {current_total} items" # noqa
)
print(
" This might indicate different dataset versions or configurations"
)
except FileNotFoundError:
print(
f"❌ Resume file not found: {self.config.resume_from_unsolved_dataset}"
)
print("Falling back to full dataset initialization...")
self.active_train_queue = list(self.train)
self.solved_items = []
self.item_attempt_counts = {}
except Exception as e:
print(f"❌ Error loading from unsolved dataset: {e}")
print("Falling back to full dataset initialization...")
self.active_train_queue = list(self.train)
self.solved_items = []
self.item_attempt_counts = {}
def _handle_item_result(self, item: Item, group_average_score: float):
"""
Handle the result of an item based on its group average score.
If solved (high score), remove from circulation.
If not solved (low score), add back to the end of the queue.
"""
_, answer_info = item
raw_item = answer_info.get("raw_item")
item_id = answer_info.get("item_id")
attempt_count = answer_info.get("attempt_count", 1)
if not raw_item or not item_id:
return # Skip if we don't have the necessary info
# Define "solved" as either too easy for training OR very high performance
is_solved = (
group_average_score > self.config.max_group_average_for_training
or group_average_score >= 0.9
) # Very high performance threshold
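# Note: with the default max_group_average_for_training = 0.75, any average above
# 0.75 already counts as solved via the first condition, so the separate 0.9
# "mastered" branch only takes effect if the training threshold is raised above 0.9.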
if is_solved:
# Item is solved - move to solved items (remove from circulation)
self.solved_items.append(raw_item)
status = "SOLVED - removed from circulation"
if group_average_score > self.config.max_group_average_for_training:
status += " (too easy)"
else:
status += " (mastered)"
else:
# Item not solved - add back to the end of the active queue
self.active_train_queue.append(raw_item)
status = "NOT SOLVED - added back to queue"
# Log the decision periodically or for items with many attempts
if attempt_count % 5 == 1 or attempt_count <= 3 or is_solved:
func_name = answer_info.get("func_name", "unknown")
print(
f"📚 Item {item_id} (attempt #{attempt_count}): {func_name} | Score: {group_average_score:.4f} | {status}" # noqa
)
print(
f" Queue status: {len(self.active_train_queue)} active, {len(self.solved_items)} solved"
)
async def add_rollouts_for_wandb(
self,
scored_data: ScoredDataGroup, # Assuming single ScoredDataGroup here
@ -736,6 +1263,78 @@ class InstructionFollowingEnv(BaseEnv):
if len(self.rollouts_for_wandb) > self.config.num_rollouts_to_keep:
self.rollouts_for_wandb.pop(0)
async def _save_rollouts_to_jsonl(self):
"""Saves the buffered rollouts to a JSONL file in the datadumps directory."""
if not self.rollouts_to_save_buffer:
print("Warning: _save_rollouts_to_jsonl called but buffer is empty!")
return
buffer_size = len(self.rollouts_to_save_buffer)
print(f"Starting save of {buffer_size} rollout groups to JSONL file...")
try:
if not os.path.exists(self.datadumps_dir):
os.makedirs(self.datadumps_dir)
print(f"Created directory: {self.datadumps_dir}")
except Exception as e:
print(f"Error creating directory {self.datadumps_dir}: {e}")
return
filename = os.path.join(
self.datadumps_dir,
f"instruction_following_rollouts_{self.run_uuid}_{self.save_file_batch_num:04d}.jsonl",
)
try:
with open(filename, "w", encoding="utf-8") as f:
for rollout_dict in self.rollouts_to_save_buffer:
json.dump(rollout_dict, f, ensure_ascii=False)
f.write("\n")
print(f"Successfully saved {buffer_size} rollout groups to {filename}")
self.save_file_batch_num += 1
self.rollouts_to_save_buffer.clear()
except Exception as e:
print(f"Error saving rollouts to {filename}: {e}")
async def _save_failed_rollouts_to_jsonl(self):
"""Saves the buffered failed rollouts to a JSONL file for debugging."""
if not self.failed_rollouts_to_save_buffer:
print("Warning: _save_failed_rollouts_to_jsonl called but buffer is empty!")
return
buffer_size = len(self.failed_rollouts_to_save_buffer)
print(f"Starting save of {buffer_size} failed rollout groups to JSONL file...")
try:
if not os.path.exists(self.datadumps_dir):
os.makedirs(self.datadumps_dir)
print(f"Created directory: {self.datadumps_dir}")
except Exception as e:
print(f"Error creating directory {self.datadumps_dir}: {e}")
return
filename = os.path.join(
self.datadumps_dir,
f"instruction_following_failed_rollouts_{self.run_uuid}_{self.failed_save_file_batch_num:04d}.jsonl",
)
try:
with open(filename, "w", encoding="utf-8") as f:
for rollout_dict in self.failed_rollouts_to_save_buffer:
json.dump(rollout_dict, f, ensure_ascii=False)
f.write("\n")
print(
f"Successfully saved {buffer_size} failed rollout groups to {filename}"
)
self.failed_save_file_batch_num += 1
self.failed_rollouts_to_save_buffer.clear()
except Exception as e:
print(f"Error saving failed rollouts to {filename}: {e}")
# ----- IFEval Verifier Functions and Map -----
# adapted from https://github.com/allenai/open-instruct/blob/main/scripts/eval_constraints/if_functions.py
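# Illustrative sketch only (the actual verifiers follow in this file and are adapted
# from open-instruct): each verifier receives the model response plus the stored args
# and reports whether the constraint is satisfied, e.g. for func_name "verify_keywords":
#
# def verify_keywords(text: str, keyword_list: list) -> bool:
#     lowered = text.lower()
#     return all(keyword.lower() in lowered for keyword in keyword_list)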