atropos/environments/reasoning_gym_environment/reasoning_gym_environment.py
2025-06-06 15:24:00 -07:00

1316 lines
54 KiB
Python

import json
import logging
import os
import pkgutil
import random
import re
import sys
import uuid
from typing import Any, Dict, List, Optional, Tuple, Union
from pydantic import Field
from tqdm.asyncio import tqdm_asyncio
import wandb
from atroposlib.envs.base import (
APIServerConfig,
BaseEnv,
BaseEnvConfig,
EvalHandlingEnum,
Item,
ScoredDataGroup,
)
from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer
# Add the local reasoning-gym submodule to Python's path
# This allows `import reasoning_gym` to find the local submodule
_SUBMODULE_DIR = os.path.abspath(
    os.path.join(os.path.dirname(__file__), "reasoning-gym")
)
if _SUBMODULE_DIR not in sys.path:
    # Prepend so the vendored submodule wins over any site-packages install.
    sys.path.insert(0, _SUBMODULE_DIR)

# Attempt to import reasoning_gym. If not found, a warning will be issued in _get_task_names.
# Both names are set to None on failure so the rest of the module can degrade
# gracefully (setup() raises a clear ImportError) instead of failing at import time.
try:
    import reasoning_gym
    from reasoning_gym.utils import extract_answer
except ImportError:
    reasoning_gym = None
    extract_answer = None
# Shared system prompt sent with every task: asks the model to reason inside
# <think></think> tags and to emit exactly one final answer inside
# <answer></answer> tags (this format is enforced by _extract_final_answer).
# NOTE(review): the final sentence is repeated verbatim — possibly deliberate
# emphasis for instruction-following; confirm before deduplicating, since this
# string is part of the runtime prompt and changing it changes model behavior.
system_prompt = (
    "You are a deep thinking AI, you may use extremely long chains of thought to deeply consider the "
    "problem and deliberate with yourself via systematic reasoning processes to help come to a correct "
    "solution prior to answering. You should enclose your thoughts and internal monologue inside <think> "
    "</think> tags, and then provide your solution or response to the problem. After your thinking, "
    "make sure to clearly provide your final answer inside <answer></answer> tags. You can provide "
    "context, explanation, etc before and after your answer tags/answer, but you must provide a single "
    "answer and place it inside <answer> tags. You must provide a single answer and place it inside "
    "<answer> tags."
)
class ReasoningGymEnvConfig(BaseEnvConfig):
    """Extended configuration for ReasoningGymEnv with additional fields."""

    # Rollout dumping switches.
    dump_rollouts: bool = Field(
        description="Whether to dump successful rollouts (above threshold) to JSONL files.",
        default=False,
    )
    dump_failed_rollouts: bool = Field(
        description="Whether to dump failed rollouts (all 0 scores) to JSONL files for debugging.",
        default=False,
    )
    # Reproducibility.
    seed: int = Field(
        description="Random seed for reproducibility.",
        default=42,
    )
    # Logging behavior.
    debug_logging: bool = Field(
        description="Enable debug-level logging for more verbose output.",
        default=False,
    )
    suppress_base_env_logs: bool = Field(
        description="Suppress verbose base environment logs (like status dict updates).",
        default=True,
    )
    # Data-dump quality gate.
    rollout_save_score_threshold: float = Field(
        description="Minimum score threshold for saving rollouts to data dumps. Only groups with at least one rollout above this threshold will be saved.",  # noqa: E501
        default=0.7,
    )
    # Evaluation-set construction.
    num_eval_samples_per_task: int = Field(
        description="Number of evaluation samples to generate per task for the test set.",
        default=5,
    )
    eval_seed: int = Field(
        description="Seed for generating fixed evaluation set to ensure reproducibility.",
        default=123,
    )

    def validate_config(self):
        """Validate configuration parameters."""
        threshold = self.rollout_save_score_threshold
        # Reject out-of-range thresholds outright; warn on the degenerate 1.0 case.
        if threshold < 0.0 or threshold > 1.0:
            raise ValueError(
                f"rollout_save_score_threshold must be between 0.0 and 1.0, got {self.rollout_save_score_threshold}"
            )
        if threshold == 1.0:
            print(
                f"Warning: rollout_save_score_threshold is {self.rollout_save_score_threshold}, which may be too strict and result in no saved rollouts."  # noqa: E501
            )
        if self.num_eval_samples_per_task <= 0:
            raise ValueError(
                f"num_eval_samples_per_task must be positive, got {self.num_eval_samples_per_task}"
            )
class ReasoningGymEnv(BaseEnv):
    """Trainer environment serving procedurally generated reasoning_gym tasks.

    Per item, it samples a random reasoning-gym task, collects a group of
    rollouts from the inference server, scores them with the task's own
    ``score_answer`` verifier (enforcing the <answer>-tag format), and can
    dump successful or all-failed rollout groups to JSONL for analysis.
    """

    name = "reasoning_gym"
    env_config_cls = ReasoningGymEnvConfig
def __init__(
    self,
    config: ReasoningGymEnvConfig,
    server_configs: List[APIServerConfig],
    slurm=True,
    testing=False,
):
    """Initialize the environment, its logger, and rollout bookkeeping state."""
    # Fail fast on invalid configuration before any base-class work happens.
    config.validate_config()
    super().__init__(config, server_configs, slurm, testing)

    # Per-class logger; attach a stream handler only when none is configured yet.
    self.logger = logging.getLogger(self.__class__.__name__)
    if not self.logger.handlers:
        stream_handler = logging.StreamHandler()
        stream_handler.setFormatter(
            logging.Formatter(
                "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
            )
        )
        self.logger.addHandler(stream_handler)
    self.logger.setLevel(
        logging.DEBUG if self.config.debug_logging else logging.INFO
    )
    # Make sure nothing upstream disabled this logger.
    self.logger.disabled = False

    # Optionally quiet the chatty INFO logs emitted by the base environment.
    if self.config.suppress_base_env_logs:
        logging.getLogger("atroposlib.envs.base").setLevel(logging.WARNING)

    # Base-class compatibility plus per-run tracking state.
    self.max_token_len = self.config.max_token_length
    self.percent_correct_buffer = list()
    self.eval_metrics = list()
    self.rollouts_for_wandb: List[List[Tuple[str, float, str, str]]] = []
    self.task_names: List[str] = []
    self.test_items_with_scorers: List[Tuple[Dict[str, Any], Any]] = []
    self.rng = random.Random()
    self.run_uuid = str(uuid.uuid4())

    # Buffers/counters for dumping successful rollouts to JSONL.
    self.rollouts_to_save_buffer: List[
        Dict[str, Union[str, List[Dict[str, Union[List[Dict[str, str]], float]]]]]
    ] = []
    self.processed_item_count = 0
    self.datadumps_dir = os.path.join(
        os.path.dirname(os.path.abspath(__file__)), "data_dumps"
    )
    self.save_file_batch_num = 0

    # Buffers/counters for dumping failed rollouts (all 0 scores) for debugging.
    self.failed_rollouts_to_save_buffer: List[
        Dict[str, Union[str, List[Dict[str, Union[List[Dict[str, str]], float]]]]]
    ] = []
    self.failed_processed_item_count = 0
    self.failed_save_file_batch_num = 0
@classmethod
def config_init(cls) -> Tuple[ReasoningGymEnvConfig, List[APIServerConfig]]:
    """Return the default environment config and inference-server configs."""
    # Single local vLLM/OpenAI-compatible endpoint used for all rollouts.
    inference_server = APIServerConfig(
        model_name="NousResearch/DeepHermes-3-Llama-3-8B-Preview",
        base_url="http://localhost:9004/v1",
        api_key="x",
        num_max_requests_at_once=32,
        num_requests_for_eval=256,
    )
    env_config = ReasoningGymEnvConfig(
        tokenizer_name="NousResearch/DeepHermes-3-Llama-3-8B-Preview",
        group_size=16,
        use_wandb=True,
        rollout_server_url="http://localhost:8000",
        total_steps=250,
        seed=1918,
        batch_size=1024,
        steps_per_eval=25,
        max_token_length=1024 * 16,
        inference_weight=1.0,
        wandb_name="reasoning_gym_think",  # Specific name for reasoning gym
        eval_handling=EvalHandlingEnum.LIMIT_TRAIN,
        eval_limit_ratio=0.1,
        num_rollouts_per_group_for_logging=4,
        num_rollouts_to_keep=50,
        dump_rollouts=False,
        dump_failed_rollouts=False,
        debug_logging=False,
        suppress_base_env_logs=True,
        rollout_save_score_threshold=0.51,
        num_eval_samples_per_task=5,
        eval_seed=123,
    )
    return env_config, [inference_server]
def _get_task_names(self) -> List[str]:
    """
    Gets task names for training, trying progressively less reliable sources.

    Order of attempts:
      1. The ``DATASETS`` registry in ``reasoning_gym.factory`` (most reliable).
      2. Dynamic discovery via ``pkgutil.walk_packages`` over the package tree.
      3. A hard-coded fallback list of known tasks.

    Whatever list is obtained is validated with ``_validate_discovered_tasks``
    before being returned.

    Returns:
        Validated task names; a small unvalidated default list when the
        reasoning_gym package itself could not be imported.
    """
    if reasoning_gym is None:
        print(
            "ERROR: The local reasoning-gym submodule could not be imported. "
            "Ensure it is present in 'atropos/environments/reasoning_gym_environment/reasoning-gym' "
            "and is a valid Python package (e.g., has __init__.py)."
        )
        # No package, so no validation possible — return a minimal known set.
        return ["algebra/linear_1d", "arithmetic/add_or_subtract", "leg_counting"]

    # Attempt 1: registry. `None` signals failure (distinct from an empty list,
    # which means "registry readable but empty" and must NOT trigger pkgutil).
    discovered_tasks = self._discover_tasks_from_registry()
    if discovered_tasks is None:
        # Attempt 2: pkgutil walk (returns [] on failure).
        discovered_tasks = self._discover_tasks_via_pkgutil()

    if not discovered_tasks:
        print(
            "WARNING: All discovery methods failed. Using fallback list of known tasks."
        )
        # reasoning_gym is known non-None here (checked above), so pass it
        # directly — the old `reasoning_gym if reasoning_gym else None` was
        # redundant.
        return self._validate_discovered_tasks(
            self._fallback_task_names(), reasoning_gym
        )

    # Validate all discovered tasks before use.
    return self._validate_discovered_tasks(discovered_tasks, reasoning_gym)

def _discover_tasks_from_registry(self) -> Optional[List[str]]:
    """Return task names from the DATASETS registry, or None if unreadable."""
    try:
        # Access the DATASETS registry directly from reasoning_gym.factory
        from reasoning_gym.factory import DATASETS

        tasks = list(DATASETS.keys())
        print(
            f"Discovered {len(tasks)} tasks from DATASETS registry: {tasks[:10]}{'...' if len(tasks) > 10 else ''}"  # noqa: E501
        )
        return tasks
    except Exception as e:
        print(
            f"WARNING: Could not access DATASETS registry: {e}. "
            "Falling back to manual discovery and validation."
        )
        return None

def _discover_tasks_via_pkgutil(self) -> List[str]:
    """Walk the reasoning_gym package tree, deriving 'domain/task' names from
    non-package module paths. Returns [] when discovery fails entirely."""
    discovered: List[str] = []
    try:
        package = reasoning_gym
        if not hasattr(package, "__path__"):
            raise AttributeError("'reasoning_gym' is not a package")
        base_module_path_parts = package.__name__.split(".")
        print(
            f"Attempting to discover tasks in package: {package.__name__} from path: {package.__path__}"
        )
        for _, modname, ispkg in pkgutil.walk_packages(
            path=package.__path__,
            prefix=package.__name__ + ".",
            onerror=lambda name: print(
                f"Error importing module during task discovery: {name}"
            ),
        ):
            if ispkg:
                continue
            module_parts = modname.split(".")
            # Only consider modules strictly below the package root.
            if (
                len(module_parts) <= len(base_module_path_parts)
                or module_parts[: len(base_module_path_parts)]
                != base_module_path_parts
            ):
                continue
            task_specific_parts = module_parts[len(base_module_path_parts):]
            # Filter out potential private/internal modules
            if any(part.startswith("_") for part in task_specific_parts):
                continue
            discovered.append("/".join(task_specific_parts))
        if discovered:
            print(
                f"Dynamically discovered {len(discovered)} potential task names from submodule."
            )
    except Exception as e2:
        print(
            f"WARNING: Pkgutil discovery also failed: {e2}. Using fallback list."
        )
    return discovered

def _fallback_task_names(self) -> List[str]:
    """Complete fallback list with all available reasoning-gym tasks."""
    return [
        "ab",
        "acre",
        "advanced_geometry",
        "aiw",
        "arc_1d",
        "arc_agi",
        "base_conversion",
        "basic_arithmetic",
        "bf",
        "binary_alternation",
        "binary_matrix",
        "bitwise_arithmetic",
        "boxnet",
        "caesar_cipher",
        "calendar_arithmetic",
        "chain_sum",
        "circuit_logic",
        "codeio",
        "color_cube_rotation",
        "complex_arithmetic",
        "composite",
        "count_bits",
        "count_primes",
        "countdown",
        "course_schedule",
        "cryptarithm",
        "decimal_arithmetic",
        "decimal_chain_sum",
        "dice",
        "emoji_mystery",
        "family_relationships",
        "figlet_font",
        "fraction_simplification",
        "futoshiki",
        "game_of_life",
        "game_of_life_halting",
        "gcd",
        "graph_color",
        "group_anagrams",
        "gsm_symbolic",
        "intermediate_integration",
        "isomorphic_strings",
        "jugs",
        "knight_swap",
        "knights_knaves",
        "largest_island",
        "lcm",
        "leg_counting",
        "letter_counting",
        "letter_jumble",
        "list_functions",
        "mahjong_puzzle",
        "manipulate_matrix",
        "maze",
        "mini_sudoku",
        "modulo_grid",
        "n_queens",
        "needle_haystack",
        "number_filtering",
        "number_format",
        "number_sequence",
        "number_sorting",
        "palindrome_generation",
        "palindrome_partitioning",
        "polynomial_equations",
        "polynomial_multiplication",
        "pool_matrix",
        "power_function",
        "prime_factorization",
        "products",
        "propositional_logic",
        "puzzle24",
        "quantum_lock",
        "ransom_note",
        "rearc",
        "rectangle_count",
        "rotate_matrix",
        "rotten_oranges",
        "rubiks_cube",
        "rush_hour",
        "self_reference",
        "sentence_reordering",
        "shortest_path",
        "simple_equations",
        "simple_geometry",
        "simple_integration",
        "sokoban",
        "spell_backward",
        "spiral_matrix",
        "string_insertion",
        "string_manipulation",
        "string_splitting",
        "string_synthesis",
        "sudoku",
        "syllogism",
        "time_intervals",
        "tower_of_hanoi",
        "tsumego",
        "word_ladder",
        "word_sequence_reversal",
        "word_sorting",
        "zebra_puzzles",
    ]
def _validate_discovered_tasks(
    self, task_names_to_validate: List[str], rg_package_or_none: Any
) -> List[str]:
    """
    Validates a list of task names by attempting to create a dataset for each.

    Args:
        task_names_to_validate: List of task names (e.g., "domain/task").
        rg_package_or_none: The imported reasoning_gym package, or None if import failed.

    Returns:
        A list of task names that were successfully validated.
    """
    if rg_package_or_none is None:
        print(
            "Validation SKIPPED: reasoning_gym package not available for validation."
        )
        return []

    print(
        f"Validating {len(task_names_to_validate)} discovered/fallback task names..."
    )
    valid_tasks: List[str] = []
    for candidate in task_names_to_validate:
        # A task counts as valid iff a size-1 dataset can be instantiated for it.
        try:
            _ = rg_package_or_none.create_dataset(candidate, size=1, seed=0)
        except Exception as e:
            print(
                f"Note: Task '{candidate}' could not be loaded (validation failed): {type(e).__name__} - {e}"
            )  # noqa: E501
            continue
        valid_tasks.append(candidate)

    if not valid_tasks:
        if task_names_to_validate:
            # Every candidate failed — likely a systematic problem, so return
            # nothing rather than unvalidated names.
            print(
                "WARNING: Validation of discovered/fallback tasks failed for all. This might indicate a systematic issue."  # noqa: E501
            )
            print("No tasks passed validation.")
            return []
        # Input list was empty as well: fall back to one known-good task.
        print(
            "CRITICAL WARNING: No valid reasoning-gym tasks could be loaded. Environment may not function."
        )
        try:
            _ = rg_package_or_none.create_dataset("leg_counting", size=1, seed=0)
            print("Falling back to absolute minimal task: 'leg_counting'")
            return ["leg_counting"]
        except Exception:
            print("Absolute fallback 'leg_counting' also failed to load.")
            return []

    print(f"Validated {len(valid_tasks)} tasks for use.")
    return valid_tasks
async def setup(self):
    """Discover/validate tasks, seed the RNG, and build the fixed eval set.

    Raises:
        ImportError: if the reasoning_gym submodule is unavailable.
        ValueError: if no tasks could be loaded.
    """
    # The reasoning_gym import is now handled at the top with sys.path modification.
    if reasoning_gym is None:
        raise ImportError(
            "reasoning-gym library could not be imported from the local submodule. "
            "This environment cannot function. Check submodule presence and integrity."
        )
    self.logger.info("Setting up ReasoningGym environment...")
    self.task_names = (
        self._get_task_names()
    )  # _get_task_names now uses self._validate_discovered_tasks
    if not self.task_names:
        raise ValueError(
            "No reasoning_gym tasks could be loaded. Environment setup failed."
        )
    self.logger.info(
        f"ReasoningGymEnv: Initialized with {len(self.task_names)} tasks."
    )
    self.logger.info(f"Sample tasks: {self.task_names[:5]}")
    # Seed for main RNG used in get_next_item
    # The seed for reasoning_gym dataset creation will be self.iter
    self.rng.seed(self.config.seed)
    self.iter = 0
    # Create a fixed test set for evaluation. Using a fixed eval_seed makes the
    # eval set reproducible across runs.
    self.logger.info("Generating fixed test set for evaluation...")
    eval_tasks_sample = self.rng.sample(
        self.task_names, min(len(self.task_names), 20)
    )  # Sample 20 tasks for eval
    for task_name in tqdm_asyncio(eval_tasks_sample, desc="Creating eval dataset"):
        try:
            # Each task gets its own dataset instance for evaluation
            # Using a fixed seed for reproducibility of the test set
            dataset = reasoning_gym.create_dataset(
                task_name,
                size=self.config.num_eval_samples_per_task,
                seed=self.config.eval_seed,
            )
            # Keep the dataset object alongside each item: it is the scorer
            # (score_answer) used later during evaluation.
            for item in dataset:
                self.test_items_with_scorers.append((item, dataset))
        except Exception as e:
            self.logger.warning(
                f"Could not create eval dataset for task '{task_name}': {e}"
            )
    if not self.test_items_with_scorers:
        self.logger.warning(
            "No evaluation items could be generated. Evaluation might be skipped or fail."
        )
    else:
        self.logger.info(
            f"Generated {len(self.test_items_with_scorers)} items for the evaluation test set."
        )
    self.logger.info(
        "ReasoningGym environment setup complete. Ready to start training!"
    )
    self.logger.info(
        f"Configuration: group_size={self.config.group_size}, max_token_length={self.config.max_token_length}, steps_per_eval={self.config.steps_per_eval}"  # noqa: E501
    )
    if self.config.dump_rollouts:
        self.logger.info(
            f"Data dumping enabled with score threshold: {self.config.rollout_save_score_threshold}"
        )
    self.logger.info(
        "Using strict <answer> tag enforcement: models must use exactly one <answer> tag or receive 0 score"
    )
    self.logger.info(
        "Using dual-format scoring for valid answers: trying both raw answers and <answer>-tagged answers, using higher score"  # noqa: E501
    )
async def get_next_item(self) -> Optional[Item]:
    """
    Get the next training item by randomly selecting a reasoning_gym task and generating a sample.

    Returns:
        A tuple: (prompt_messages, reasoning_gym_item, reasoning_gym_dataset_object)
        prompt_messages: Formatted for the language model.
        reasoning_gym_item: The raw item from reasoning_gym (dict with 'question', 'answer', 'metadata').
        reasoning_gym_dataset_object: The dataset object from which the item was generated (for scoring).
        Returns None when no tasks exist or item generation fails.
    """
    if not self.task_names:
        return None  # Should not happen if setup is correct

    selected_task_name = self.rng.choice(self.task_names)
    try:
        # Fresh dataset per item so any internal task state cannot leak between
        # items; the seed varies with self.iter to get new questions over time.
        current_seed = self.config.seed + self.iter
        dataset_obj = reasoning_gym.create_dataset(
            selected_task_name, size=1, seed=current_seed
        )
        rg_item = next(iter(dataset_obj))
        # Log task selection every 10 items to avoid spam
        if self.iter % 10 == 0:
            self.logger.info(
                f"Selected task: {selected_task_name} (iteration {self.iter})"
            )
    except Exception as e:
        self.logger.warning(
            f"Error generating item for task {selected_task_name} with seed {current_seed}: {e}"
        )
        return None  # Skip this item if generation fails
    self.iter += 1

    # Build the prompt as hashable frozensets (required by the Item contract).
    conversation = (
        ("system", system_prompt),
        ("user", rg_item["question"]),
    )
    prompt_messages = tuple(
        frozenset({"role": role, "content": content}.items())
        for role, content in conversation
    )
    # The 'answer' inside rg_item is the data needed for scoring later.
    return (prompt_messages, rg_item, dataset_obj)
def _extract_final_answer(self, text: str) -> Optional[str]:
    """
    Extracts the content from <answer> tags using reasoning-gym's extract_answer function.

    Enforces strict compliance with the answer-tag instructions: returns None
    when there are no <answer> tags, or when more than one <answer> block
    appears outside of <think> blocks.
    NOTE(review): a single <answer> tag that sits entirely inside a <think>
    block still extracts (extract_answer scans the full text) — confirm this
    leniency is intended.
    """  # noqa: E501
    if extract_answer is None:
        # Without reasoning-gym we cannot enforce the format at all.
        if self.config.debug_logging:
            self.logger.debug(
                "reasoning-gym extract_answer not available, cannot enforce answer tag format"
            )
        return None

    # Strip every <think>...</think> region, then count <answer> blocks in what
    # remains: more than one outside-think answer means the model ignored the
    # "single answer" instruction.
    flags = re.DOTALL | re.IGNORECASE
    outside_think = re.sub(r"<think>.*?</think>", "", text, flags=flags)
    outside_answers = re.findall(r"<answer>.*?</answer>", outside_think, flags=flags)
    if len(outside_answers) > 1:
        if self.config.debug_logging:
            self.logger.debug(
                f"Model provided {len(outside_answers)} answer tags outside think blocks - failing for not following single answer instruction"  # noqa: E501
            )
        return None

    # Delegate the actual extraction to reasoning-gym.
    answer_content = extract_answer(text, tag_name="answer", strip=True)
    if answer_content is None:
        # No valid answer tags found - model failed to follow instructions
        if self.config.debug_logging:
            self.logger.debug(
                "No <answer> tags found - model failed to follow answer format instructions"
            )
        return None

    if self.config.debug_logging:
        self.logger.debug(
            f"Successfully extracted answer from <answer> tags: '{answer_content[:100]}{'...' if len(answer_content) > 100 else ''}'"  # noqa: E501
        )
    return answer_content
def _score_answer_with_both_formats(
    self, model_answer: str, rg_item: dict, dataset_obj: Any
) -> float:
    """
    Score the answer using both formats (with and without <answer> tags) and return the higher score.

    Some reasoning-gym verifiers expect the bare answer, others the tagged
    form; trying both and taking the max handles either convention.

    Args:
        model_answer: Extracted answer text (without tags).
        rg_item: The reasoning_gym item being scored.
        dataset_obj: Dataset object providing score_answer().

    Returns:
        The higher of the two clamped [0.0, 1.0] scores.
    """
    # Format 1: the raw answer content. Format 2: wrapped in <answer> tags.
    score1 = self._try_score_answer(model_answer, rg_item, dataset_obj, "1")
    score2 = self._try_score_answer(
        f"<answer>{model_answer}</answer>", rg_item, dataset_obj, "2"
    )
    final_score = max(score1, score2)

    # Log which format worked better (only in debug mode)
    if self.config.debug_logging and score1 != score2:
        task_name = rg_item.get("metadata", {}).get("source_dataset", "unknown")
        if score1 > score2:
            self.logger.debug(
                f"Task {task_name}: Raw answer format scored higher ({score1:.3f} vs {score2:.3f})"
            )
        else:
            self.logger.debug(
                f"Task {task_name}: Tagged answer format scored higher ({score2:.3f} vs {score1:.3f})"
            )
    return final_score

def _try_score_answer(
    self, answer_text: str, rg_item: dict, dataset_obj: Any, format_label: str
) -> float:
    """Score one candidate answer string, clamped to [0.0, 1.0].

    Returns 0.0 (and logs at debug level) when the task's verifier raises.
    """
    try:
        raw_score = dataset_obj.score_answer(answer=answer_text, entry=rg_item)
        return max(0.0, min(1.0, float(raw_score)))
    except Exception as e:
        task_name = rg_item.get("metadata", {}).get("source_dataset", "unknown")
        self.logger.debug(
            f"Error scoring answer format {format_label} for task {task_name}: {e}"
        )
        return 0.0
async def score(
    self,
    rollout_group_data: List[Tuple[Tuple[Dict[str, str]], Dict[str, Any], Any]],
) -> Optional[ScoredDataGroup]:
    """
    Scores a group of rollouts using reasoning_gym's score_answer method.

    Args:
        rollout_group_data: A list of tuples, where each tuple contains:
            - trajectory_messages: The full conversation history for the rollout.
            - rg_item: The original reasoning_gym item (contains 'question', 'answer', 'metadata').
            - dataset_obj: The reasoning_gym dataset object used to generate and score the item.

    Returns:
        ScoredDataGroup with scores between 0.0 and 1.0, or None if no valid items
        (or if all scores are identical, since that carries no learning signal).
    """
    scores_container = ScoredDataGroup()
    scores_container["tokens"] = list()
    scores_container["masks"] = list()
    scores_container["scores"] = list()
    if not rollout_group_data:
        return None
    # All rollouts in a group came from the same item, so take the shared
    # item/scorer from the first entry.
    rg_item_for_group = rollout_group_data[0][1]
    dataset_obj_for_group = rollout_group_data[0][2]
    # Shuffle to avoid bias in selection
    random.shuffle(rollout_group_data)
    for trajectory_messages, _, _ in rollout_group_data:
        # The assistant completion is always the last message of the trajectory.
        model_full_response = trajectory_messages[-1]["content"]
        # Extract the part of the response that should be the answer
        model_answer_to_score = self._extract_final_answer(model_full_response)
        # If extraction failed (model didn't follow instructions), give 0 score
        if model_answer_to_score is None:
            reward_0_to_1 = 0.0
            if self.config.debug_logging:
                task_name = rg_item_for_group.get("metadata", {}).get(
                    "source_dataset", "unknown"
                )
                self.logger.debug(
                    f"Task {task_name}: Giving 0 score due to failed answer extraction (didn't follow format)"
                )
        else:
            # Use our dual-format scoring method that tries both raw answer and tagged answer
            reward_0_to_1 = self._score_answer_with_both_formats(
                model_answer_to_score, rg_item_for_group, dataset_obj_for_group
            )
        # Tokenize the conversation for learning
        out_dict = tokenize_for_trainer(self.tokenizer, trajectory_messages)
        tokens = out_dict["tokens"]
        masks = out_dict["masks"]
        # Remove examples with insufficient context: fewer than 10 unmasked
        # (trainable, != -100) positions.
        if len([1 for i in masks if i != -100]) < 10:
            continue
        scores_container["tokens"].append(tokens)
        scores_container["masks"].append(masks)
        scores_container["scores"].append(reward_0_to_1)
        # Stop once a full group has been collected.
        if len(scores_container["tokens"]) >= self.config.group_size:
            break
    if not scores_container["tokens"]:
        self.logger.warning(
            "No valid items were scored in this group - all items had insufficient context or failed scoring"
        )
        return None
    # Record success rate metrics (convert to binary for percent_correct tracking)
    # NOTE(review): this buffer is updated even when the group is later discarded
    # for having identical scores — confirm that is intended for the metric.
    for score_val in scores_container["scores"]:
        self.percent_correct_buffer.append(1.0 if score_val >= 0.5 else 0.0)
    # Calculate and log average score for the current group
    current_scores = scores_container.get("scores", [])
    if current_scores:
        average_score = sum(current_scores) / len(current_scores)
        task_name = rg_item_for_group.get("metadata", {}).get(
            "source_dataset", "unknown_task"
        )
        log_message_main = (
            f"Task: {task_name} | Group average score: {average_score:.4f}"
        )
        if all(s >= 0.5 for s in current_scores):
            self.logger.info(f"{log_message_main} (All correct in this group!)")
        elif all(s == 0.0 for s in current_scores):
            self.logger.info(f"{log_message_main} (All failed - no valid answers!)")
        elif all(s < 0.5 for s in current_scores):
            self.logger.info(
                f"{log_message_main} (All incorrect but some partial credit!)"
            )
        else:
            self.logger.info(log_message_main)
    # Apply length penalty if all responses are correct (score >= 0.5):
    # when everything is right, shorter answers are preferred.
    if all(s >= 0.5 for s in scores_container["scores"]):
        # Calculate token lengths (each entry of "tokens" is one rollout's
        # token-id list, so len() is its length).
        token_lengths = [len(token) for token in scores_container["tokens"]]
        if max(token_lengths) == 0:
            # Degenerate guard: nothing to penalize.
            return None
        # Get max allowed token length from config
        max_allowed_length = self.config.max_token_length
        # Set threshold at 75% of max_token_length
        length_threshold = max_allowed_length * 0.75
        # Apply modified length penalty with threshold
        new_scores = []
        penalties_applied = 0
        for i, length in enumerate(token_lengths):
            original_score = scores_container["scores"][i]
            if length <= length_threshold:
                new_scores.append(original_score)
            else:
                # Calculate how far we are between threshold and max as a percentage
                percentage_of_range = (length - length_threshold) / (
                    max_allowed_length - length_threshold
                )
                percentage_of_range = min(percentage_of_range, 1.0)
                # Apply linear penalty scaling from original_score down to 0.0
                penalized_score = original_score * (1.0 - percentage_of_range)
                new_scores.append(penalized_score)
                penalties_applied += 1
        if penalties_applied > 0:
            avg_length = sum(token_lengths) / len(token_lengths)
            self.logger.debug(
                f"Applied length penalty to {penalties_applied}/{len(token_lengths)} responses (avg length: {avg_length:.0f}, threshold: {length_threshold:.0f})"  # noqa: E501
            )
        scores_container["scores"] = new_scores
    # Check if all scores are the same (no learning signal)
    if all(
        scores_container["scores"][0] == score
        for score in scores_container["scores"]
    ):
        self.logger.debug(
            f"All scores in group are identical ({scores_container['scores'][0]:.4f}) - no learning signal, skipping group"  # noqa: E501
        )
        # Before returning None, check if this is a completely failed group (all 0.0 scores) for debugging
        if self.config.dump_failed_rollouts and all(
            score == 0.0 for score in scores_container["scores"]
        ):
            self.logger.debug(
                "Saving failed group (all 0 scores) for debugging analysis"
            )
            await self._save_failed_group_for_debugging(
                rollout_group_data, scores_container
            )
        return None
    return scores_container
async def _save_failed_group_for_debugging(
    self, rollout_group_data, scores_container
):
    """Helper method to save failed groups (all 0 scores) for debugging analysis."""
    # Pair each rollout with its score; zip naturally drops rollouts beyond
    # the number of recorded scores.
    failed_rollouts_with_scores_to_save = [
        {
            "conversation": trajectory_messages,  # Full conversation history
            "score": rollout_score,
        }
        for (trajectory_messages, _item, _ds), rollout_score in zip(
            rollout_group_data, scores_container["scores"]
        )
    ]
    if not failed_rollouts_with_scores_to_save:
        return

    # Identify the task from the first rollout's item metadata.
    first_rg_item = rollout_group_data[0][1]
    item_id = first_rg_item.get("metadata", {}).get("source_dataset", "unknown_task")
    self.failed_rollouts_to_save_buffer.append(
        {
            "item_id": item_id,
            "rollouts": failed_rollouts_with_scores_to_save,
        }
    )
    self.failed_processed_item_count += 1

    # Progress within the current 50-item batch; show 50/50 rather than 0/50
    # when exactly at the save threshold.
    failed_batch_progress = self.failed_processed_item_count % 50 or 50
    # Log progress every 10 failed items or when we hit the save threshold.
    if failed_batch_progress % 10 == 0 or failed_batch_progress == 50:
        self.logger.info(
            f"Failed rollouts progress: {failed_batch_progress}/50 items buffered "
            f"(Total failed processed: {self.failed_processed_item_count}, Failed buffer size: {len(self.failed_rollouts_to_save_buffer)})"  # noqa: E501
        )

    # Flush a batch of failed rollouts every 50 items (successful rollouts use 100).
    if (
        self.config.dump_failed_rollouts
        and self.failed_processed_item_count % 50 == 0
        and self.failed_processed_item_count > 0
    ):
        failed_log_msg = (
            f"Reached {self.failed_processed_item_count} failed items. "
            f"Triggering save for {len(self.failed_rollouts_to_save_buffer)} failed items "
            f"(each with multiple failed rollouts)."
        )
        self.logger.info(failed_log_msg)
        await self._save_failed_rollouts_to_jsonl()
async def collect_trajectories(
    self, item: Item
) -> Tuple[Optional[ScoredDataGroup], List]:
    """
    Generate model responses for the given item and then score them.

    Args:
        item: A tuple from get_next_item: (prompt_messages, rg_item, dataset_obj)

    Returns:
        (scored_data_group_or_None, backlog_items) — backlog is always empty here.
    """
    prompt_messages_tuple, rg_item, dataset_obj = item
    # Apply chat template to convert messages to a single string
    # Ensure prompt_messages_tuple is correctly formatted list of dicts
    messages_for_template = [dict(msg_fset) for msg_fset in prompt_messages_tuple]
    prompt_str = self.tokenizer.apply_chat_template(
        messages_for_template, add_generation_prompt=True, tokenize=False
    )
    # Calculate max_tokens like tool_calling_server: leave room for the prompt,
    # capped at 15k completion tokens.
    prompt_tokens = len(self.tokenizer.encode(prompt_str))
    max_tokens = min(1024 * 15, self.config.max_token_length - prompt_tokens)
    # One request returning group_size sampled completions.
    completions = await self.server.completion(
        prompt=prompt_str,
        n=self.config.group_size,
        max_tokens=max_tokens,
        temperature=0.8,
    )
    to_score_list = []
    for choice in completions.choices:
        self.completion_lengths.append(len(choice.text))
        # Create full trajectory messages for this choice
        current_trajectory_messages = list(messages_for_template)
        current_trajectory_messages.append(
            {"role": "assistant", "content": choice.text}
        )
        to_score_list.append(
            (tuple(current_trajectory_messages), rg_item, dataset_obj)
        )
    scored_data_group = await self.score(to_score_list)
    # If rollouts were generated and scored, and data dumping is enabled, prepare them for saving
    if scored_data_group and self.config.dump_rollouts:
        # Only save groups that have at least one rollout with score > threshold
        group_scores = scored_data_group.get("scores", [])
        threshold = self.config.rollout_save_score_threshold
        if any(score > threshold for score in group_scores):
            self.logger.debug(
                f"Saving group with scores: {[f'{s:.3f}' for s in group_scores]} (has high-quality rollout, threshold: {threshold})"  # noqa: E501
            )
            rollouts_with_scores_to_save = []
            num_scored_rollouts = len(group_scores)
            # to_score_list and group_scores are index-aligned (score() may
            # have kept fewer than group_size entries).
            for i in range(num_scored_rollouts):
                conversation_messages = to_score_list[i][0]
                score_for_rollout = group_scores[i]
                rollouts_with_scores_to_save.append(
                    {
                        "conversation": conversation_messages,
                        "score": score_for_rollout,
                    }
                )
            if rollouts_with_scores_to_save:
                # Extract item info for logging
                # (re-unpacks `item`; rg_item is already bound above — redundant
                # but harmless)
                _, rg_item, _ = item
                item_id = rg_item.get("metadata", {}).get(
                    "source_dataset", "unknown_task"
                )
                item_data_to_save = {
                    "item_id": item_id,
                    "rollouts": rollouts_with_scores_to_save,
                }
                self.rollouts_to_save_buffer.append(item_data_to_save)
                self.processed_item_count += 1
                # Calculate progress toward next save (batches of 100 items)
                current_batch_progress = self.processed_item_count % 100
                if current_batch_progress == 0:
                    current_batch_progress = 100
                # Log progress every 10 items or when we hit the save threshold
                if current_batch_progress % 10 == 0 or current_batch_progress == 100:
                    self.logger.info(
                        f"Data dump progress: {current_batch_progress}/100 items buffered "
                        f"(Total processed: {self.processed_item_count}, Buffer size: {len(self.rollouts_to_save_buffer)})"  # noqa: E501
                    )
                # Check if it's time to save a batch of rollouts
                if (
                    self.config.dump_rollouts
                    and self.processed_item_count % 100 == 0
                    and self.processed_item_count > 0
                ):
                    log_msg = (
                        f"Reached {self.processed_item_count} processed items. "
                        f"Triggering save for {len(self.rollouts_to_save_buffer)} items "
                        f"(each with multiple scored rollouts)."
                    )
                    self.logger.info(log_msg)
                    await self._save_rollouts_to_jsonl()
        else:
            max_score = max(group_scores) if group_scores else 0.0
            self.logger.debug(
                f"Skipping group save - no high-quality rollouts (max score: {max_score:.3f}, threshold: {threshold})"  # noqa: E501
            )
    to_backlog = []
    return scored_data_group, to_backlog
async def _save_rollouts_to_jsonl(self):
"""Saves the buffered rollouts to a JSONL file in the datadumps directory."""
if not self.rollouts_to_save_buffer:
self.logger.info("No rollouts in buffer to save.")
return
try:
if not os.path.exists(self.datadumps_dir):
os.makedirs(self.datadumps_dir)
self.logger.info(f"Created directory: {self.datadumps_dir}")
except OSError as e:
self.logger.error(f"Error creating directory {self.datadumps_dir}: {e}")
return
file_path = os.path.join(
self.datadumps_dir,
f"reasoning_gym_environment_rollouts_{self.run_uuid}_{self.save_file_batch_num:04d}.jsonl",
)
try:
with open(file_path, "w") as f:
for rollout_dict in self.rollouts_to_save_buffer:
json.dump(rollout_dict, f)
f.write("\n")
self.logger.info(
f"Successfully saved {len(self.rollouts_to_save_buffer)} rollouts to {file_path}"
)
self.rollouts_to_save_buffer.clear()
self.save_file_batch_num += 1
except IOError as e:
self.logger.error(f"Error writing rollouts to {file_path}: {e}")
except Exception as e:
self.logger.error(
f"An unexpected error occurred while saving rollouts to {file_path}: {e}"
)
async def _save_failed_rollouts_to_jsonl(self):
"""Saves the buffered failed rollouts (all 0 scores) to a JSONL file for debugging."""
if not self.failed_rollouts_to_save_buffer:
self.logger.info("No failed rollouts in buffer to save.")
return
try:
if not os.path.exists(self.datadumps_dir):
os.makedirs(self.datadumps_dir)
self.logger.info(f"Created directory: {self.datadumps_dir}")
except OSError as e:
self.logger.error(f"Error creating directory {self.datadumps_dir}: {e}")
return
file_path = os.path.join(
self.datadumps_dir,
f"reasoning_gym_environment_FAILED_rollouts_{self.run_uuid}_{self.failed_save_file_batch_num:04d}.jsonl",
)
try:
with open(file_path, "w") as f:
for rollout_dict in self.failed_rollouts_to_save_buffer:
json.dump(rollout_dict, f)
f.write("\n")
self.logger.info(
f"Successfully saved {len(self.failed_rollouts_to_save_buffer)} FAILED rollouts to {file_path}"
)
self.failed_rollouts_to_save_buffer.clear()
self.failed_save_file_batch_num += 1
except IOError as e:
self.logger.error(f"Error writing failed rollouts to {file_path}: {e}")
except Exception as e:
self.logger.error(
f"An unexpected error occurred while saving failed rollouts to {file_path}: {e}"
)
async def rollout_and_score_eval(
self, test_data_tuple: Tuple[Dict[str, Any], Any]
) -> float:
"""
Performs a rollout for a single evaluation item and scores it.
Args:
test_data_tuple: A tuple (rg_item, dataset_obj) from self.test_items_with_scorers.
Returns:
Score (1.0 for correct, 0.0 for incorrect/error).
"""
rg_item, dataset_obj = test_data_tuple
question_text = rg_item["question"]
messages = [
{"role": "system", "content": system_prompt},
{"role": "user", "content": question_text},
]
prompt_str = self.tokenizer.apply_chat_template(
messages, add_generation_prompt=True, tokenize=False
)
# Calculate max_tokens like tool_calling_server
prompt_tokens = len(self.tokenizer.encode(prompt_str))
max_tokens = min(1024 * 15, self.config.max_token_length - prompt_tokens)
completion = await self.server.completion(
prompt=prompt_str,
n=1,
max_tokens=max_tokens,
temperature=0.1,
split="eval",
)
model_full_response = completion.choices[0].text
model_answer_to_score = self._extract_final_answer(model_full_response)
# If extraction failed (model didn't follow instructions), give 0 score
if model_answer_to_score is None:
if self.config.debug_logging:
task_name = rg_item.get("metadata", {}).get("source_dataset", "unknown")
self.logger.debug(
f"Eval - Task {task_name}: Giving 0 score due to failed answer extraction (didn't follow format)"
)
return 0.0
# Use our dual-format scoring method for evaluation as well
return self._score_answer_with_both_formats(
model_answer_to_score, rg_item, dataset_obj
)
async def evaluate(self, *args, **kwargs):
self.logger.info("Starting evaluation...")
if not self.test_items_with_scorers:
self.logger.warning("No test items available for evaluation. Skipping.")
self.eval_metrics.append(("eval/percent_correct", 0.0))
return
eval_tasks = [
self.rollout_and_score_eval(item_tuple)
for item_tuple in self.test_items_with_scorers
]
self.logger.info(
f"Starting evaluation on {len(self.test_items_with_scorers)} items..."
)
scores = await tqdm_asyncio.gather(*eval_tasks, desc="Evaluating")
if not scores:
percent_correct = 0.0
else:
percent_correct = sum(scores) / len(scores)
self.eval_metrics.append(("eval/percent_correct", percent_correct))
self.logger.info(f"Evaluation finished. Percent correct: {percent_correct:.4f}")
async def add_rollouts_for_wandb(
self,
scored_data: Union[ScoredDataGroup, List[ScoredDataGroup]],
item: Item = None,
):
if item is None or scored_data is None or not scored_data.get("tokens"):
return
_, rg_item, _ = item
expected_answer = str(rg_item.get("answer", "N/A"))
task_name = str(
rg_item.get("metadata", {}).get(
"source_dataset", rg_item.get("task_name", "unknown_task")
)
)
# save rollout to trajectory
num_keep = self.config.num_rollouts_per_group_for_logging
if num_keep == -1:
num_keep = self.config.group_size
# Make sure there's data to log
num_keep = min(num_keep, len(scored_data["tokens"]))
if num_keep == 0:
return
current_rollouts = []
for i in range(num_keep):
# Ensure tokens and scores have the same length
if i < len(scored_data["tokens"]) and i < len(scored_data["scores"]):
# Decode the full trajectory including prompt and model response
full_text = self.tokenizer.decode(
scored_data["tokens"][i], skip_special_tokens=True
)
score_val = scored_data["scores"][i]
current_rollouts.append(
(full_text, score_val, expected_answer, task_name)
)
else:
print(
f"Warning: Mismatch in lengths of tokens/scores for wandb logging at index {i}."
)
self.rollouts_for_wandb.append(current_rollouts)
if len(self.rollouts_for_wandb) > self.config.num_rollouts_to_keep:
self.rollouts_for_wandb.pop(0)
async def create_rollout_table(self, wandb_metrics: Dict) -> Dict:
if len(self.rollouts_for_wandb) > 0:
table = wandb.Table(
columns=["full_text", "score", "expected_answer", "task_name"]
)
for group_rollouts in self.rollouts_for_wandb:
for rollout_tuple in group_rollouts:
# Ensure rollout_tuple has exactly 4 elements as defined in columns
if len(rollout_tuple) == 4:
table.add_data(*rollout_tuple)
else:
print(
f"Warning: Skipping malformed rollout_tuple for wandb table: {rollout_tuple}"
)
wandb_metrics["train/rollouts"] = table
self.rollouts_for_wandb = []
return wandb_metrics
async def wandb_log(self, wandb_metrics: Optional[Dict] = None):
"""
Log to wandb with comprehensive metrics.
"""
if wandb_metrics is None:
wandb_metrics = dict()
# Try to calculate percent_correct, skip if there's a division by zero
try:
wandb_metrics["train/percent_correct"] = sum(
self.percent_correct_buffer
) / len(self.percent_correct_buffer)
except ZeroDivisionError:
pass
self.percent_correct_buffer = list()
# Add eval metrics
for item in self.eval_metrics:
wandb_metrics[item[0]] = item[1]
self.eval_metrics = list()
# Add rollout table
wandb_metrics = await self.create_rollout_table(wandb_metrics)
# Call superclass wandb_log
await super().wandb_log(wandb_metrics)
def save_checkpoint(self, step, data=None):
"""Save checkpoint including current iteration number, completion lengths, and data dumping state."""
if data is None:
data = {}
data["iter"] = self.iter
data["rng_state"] = self.rng.getstate()
data["completion_lengths"] = self.completion_lengths
data["processed_item_count"] = self.processed_item_count
data["save_file_batch_num"] = self.save_file_batch_num
data["failed_processed_item_count"] = self.failed_processed_item_count
data["failed_save_file_batch_num"] = self.failed_save_file_batch_num
super().save_checkpoint(step, data)
def load_checkpoint(self):
"""Load checkpoint including iteration number, completion lengths, and data dumping state."""
# Call the base class method first to load the data
super().load_checkpoint()
# The base class loads data into attributes, so we can access them directly
# if they were saved in save_checkpoint
if hasattr(self, "iter"):
# Data was loaded successfully, no need to do anything else
pass
async def close(self):
"""Clean up and save any remaining rollouts before exiting."""
self.logger.info(
"Closing ReasoningGymEnv. Attempting to save any remaining rollouts..."
)
if (
self.config.dump_rollouts and self.rollouts_to_save_buffer
): # Check if there's anything to save
self.logger.info(
f"Found {len(self.rollouts_to_save_buffer)} rollouts in buffer. Saving now."
)
await self._save_rollouts_to_jsonl()
else:
self.logger.info("No rollouts in buffer to save upon closing.")
# Also save any remaining failed rollouts
if self.config.dump_failed_rollouts and self.failed_rollouts_to_save_buffer:
self.logger.info(
f"Found {len(self.failed_rollouts_to_save_buffer)} failed rollouts in buffer. Saving now."
)
await self._save_failed_rollouts_to_jsonl()
else:
self.logger.info("No failed rollouts in buffer to save upon closing.")
# Call the superclass's close method if it exists
if hasattr(super(), "close"):
await super().close()
self.logger.info("ReasoningGymEnv closed.")
# Script entry point: launch the environment through the Atropos CLI harness.
if __name__ == "__main__":
    ReasoningGymEnv.cli()