diff --git a/.gitmodules b/.gitmodules
index 3c941b35..ef77a7f2 100644
--- a/.gitmodules
+++ b/.gitmodules
@@ -1,3 +1,6 @@
 [submodule "environments/intern_bootcamp/internbootcamp_lib"]
 	path = environments/intern_bootcamp/internbootcamp_lib
 	url = https://github.com/InternLM/InternBootcamp.git
+[submodule "environments/reasoning_gym_environment/reasoning-gym"]
+	path = environments/reasoning_gym_environment/reasoning-gym
+	url = https://github.com/open-thought/reasoning-gym
diff --git a/environments/reasoning_gym_environment/README.md b/environments/reasoning_gym_environment/README.md
new file mode 100644
index 00000000..0eeefca0
--- /dev/null
+++ b/environments/reasoning_gym_environment/README.md
@@ -0,0 +1,216 @@
# ReasoningGym Environment

A reinforcement learning environment for training language models on diverse reasoning tasks using the [reasoning-gym](https://github.com/open-thought/reasoning-gym) library.

## Overview

The ReasoningGym environment provides access to 100+ reasoning tasks spanning mathematics, logic, programming, and more. It supports:

- **Diverse Task Types**: Arithmetic, algebra, logic puzzles, programming challenges, and more
- **Strict Answer Format Enforcement**: Models must use `<answer>` tags or receive a 0 score
- **Dual-Format Scoring**: Tries both raw answers and `<answer>`-tagged answers, using the higher score
- **Data Collection**: Optional rollout dumping for successful and failed attempts
- **Comprehensive Logging**: Detailed progress tracking and debugging information

## Features

### Task Diversity
- 100+ tasks from reasoning-gym, including GSM Symbolic, ARC, Sudoku, and more
- Automatic task discovery from the reasoning-gym registry
- Fallback to a comprehensive task list if discovery fails

### Scoring System
- **Binary Tasks**: 0.0 or 1.0 (most tasks)
- **Partial Credit**: Some tasks, such as GSM Symbolic, give 0.01 for a wrong but well-formed number
- **Continuous Scoring**: Word Ladder and Sentence Reordering use percentage-based scoring
- **Length Penalty**: Applied to overly long responses when all responses in a group are correct (see the sketch below)
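
Both scoring rules are easy to state in code. The sketch below is illustrative (the helper names here are ours, not part of the environment's API); the real logic lives in `_score_answer_with_both_formats` and `score` in `reasoning_gym_environment.py`:

```python
def dual_format_score(dataset, entry, model_answer: str) -> float:
    """Score the raw answer and the <answer>-wrapped answer; keep the higher."""
    def safe_score(ans: str) -> float:
        try:
            # reasoning-gym datasets expose score_answer(answer=..., entry=...)
            return max(0.0, min(1.0, float(dataset.score_answer(answer=ans, entry=entry))))
        except Exception:
            return 0.0
    return max(safe_score(model_answer), safe_score(f"<answer>{model_answer}</answer>"))


def length_penalty(score: float, n_tokens: int, max_len: int) -> float:
    """Linearly scale a score toward 0 once length passes 75% of max_len."""
    threshold = max_len * 0.75
    if n_tokens <= threshold:
        return score
    overshoot = min((n_tokens - threshold) / (max_len - threshold), 1.0)
    return score * (1.0 - overshoot)
```

The penalty is only applied when every rollout in a group is already correct, so it discourages padding without punishing long reasoning that was needed to reach a right answer.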
### Data Collection
- **Successful Rollouts**: Save groups with scores above a configurable threshold
- **Failed Rollouts**: Save completely failed groups (all 0 scores) for debugging
- **Progress Tracking**: Shows buffer progress toward save thresholds
- **JSONL Format**: Easy to process saved data

## Configuration

### Key Parameters

```python
class ReasoningGymEnvConfig(BaseEnvConfig):
    dump_rollouts: bool = False  # Save successful rollouts
    dump_failed_rollouts: bool = False  # Save failed rollouts for debugging
    rollout_save_score_threshold: float = 0.7  # Minimum score to save a group
    debug_logging: bool = False  # Enable verbose logging
    suppress_base_env_logs: bool = True  # Hide base environment logs
    seed: int = 42  # Random seed for reproducibility
```

### Example Configuration

```python
env_config = ReasoningGymEnvConfig(
    tokenizer_name="NousResearch/DeepHermes-3-Llama-3-8B-Preview",
    group_size=16,
    max_token_length=1024 * 16,
    dump_rollouts=True,
    dump_failed_rollouts=True,
    rollout_save_score_threshold=0.7,
    debug_logging=True,
)
```

## Setup

### Prerequisites

1. **reasoning-gym submodule**: The submodule is registered in `.gitmodules` and points at https://github.com/open-thought/reasoning-gym. Initialize it from the repository root:
   ```bash
   git submodule update --init environments/reasoning_gym_environment/reasoning-gym
   ```

2. **Dependencies**: Install requirements:
   ```bash
   pip install -r requirements.txt
   ```

### Directory Structure
```
reasoning_gym_environment/
├── reasoning_gym_environment.py   # Main environment code
├── reasoning-gym/                 # Git submodule
├── data_dumps/                    # Generated rollout data (created automatically)
├── requirements.txt               # Dependencies
└── README.md                      # This file
```

## Usage

### Basic Training

```python
from atropos.environments.reasoning_gym_environment import ReasoningGymEnv

# Initialize environment
env_config, server_configs = ReasoningGymEnv.config_init()
env = ReasoningGymEnv(env_config, server_configs)

# Setup and run
await env.setup()
# Training loop handled by atropos framework
```

### Command Line

```bash
python reasoning_gym_environment.py
```

## System Prompt

The environment uses a structured reasoning prompt that encourages models to:

1. Use `<think>` tags for internal reasoning
2. Provide final answers in `<answer>` tags
3. Follow strict format requirements

Example model response:
```
<think>
This is a math problem. Let me work through it step by step.
2 + 3 = 5
</think>

Looking at this problem, I need to add 2 and 3.

<answer>5</answer>
```

## Data Output

### Successful Rollouts
Saved to `data_dumps/reasoning_gym_environment_rollouts_{uuid}_{batch}.jsonl`:

```json
{
  "item_id": "gsm_symbolic",
  "rollouts": [
    {
      "conversation": [
        {"role": "system", "content": "..."},
        {"role": "user", "content": "What is 2 + 3?"},
        {"role": "assistant", "content": "<think>2 + 3 = 5</think>\n<answer>5</answer>"}
      ],
      "score": 1.0
    }
  ]
}
```

### Failed Rollouts
Saved to `data_dumps/reasoning_gym_environment_FAILED_rollouts_{uuid}_{batch}.jsonl` in the same format, but with all scores equal to 0.0.

## Logging

The environment provides comprehensive logging:

- **Setup**: Task discovery and initialization
- **Training**: Group scores, task selection, progress tracking
- **Data Dumping**: Save progress and file creation
- **Format Violations**: When models don't follow answer tag requirements
- **Debug Mode**: Detailed scoring and extraction information

## Task Examples

### Mathematics
- **GSM Symbolic**: Grade school math with symbolic reasoning
- **Basic Arithmetic**: Addition, subtraction, multiplication, division
- **Algebra**: Linear equations and polynomial manipulation

### Logic
- **Sudoku**: Classic number placement puzzles
- **Propositional Logic**: Boolean reasoning tasks
- **Knights and Knaves**: Logic puzzles with truth-tellers and liars

### Programming
- **ARC**: Abstraction and Reasoning Corpus visual patterns
- **Code Generation**: Simple programming challenges
- **Algorithm Design**: Sorting, searching, and optimization

## Troubleshooting

### Common Issues

1. **No tasks discovered**: Ensure the reasoning-gym submodule is properly initialized
2. **Import errors**: Check that the requirements.txt dependencies are installed
3. **No rollouts saved**: Verify `dump_rollouts=True` and that scores exceed the threshold
4. **Format violations**: Models not using `<answer>` tags receive 0 scores (see the extraction sketch below)
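
When debugging format violations, it helps to know exactly what the strict extractor enforces. The sketch below is a simplified approximation (the actual implementation delegates the final extraction step to `reasoning_gym.utils.extract_answer`):

```python
import re
from typing import Optional

def extract_strict(text: str) -> Optional[str]:
    # Ignore <answer> tags that appear inside <think> blocks...
    outside = re.sub(r"<think>.*?</think>", "", text, flags=re.DOTALL | re.IGNORECASE)
    # ...then reject responses with more than one <answer> tag outside them.
    if len(re.findall(r"<answer>.*?</answer>", outside, flags=re.DOTALL | re.IGNORECASE)) > 1:
        return None  # multiple answers: format violation, scored 0
    match = re.search(r"<answer>(.*?)</answer>", text, flags=re.DOTALL | re.IGNORECASE)
    return match.group(1).strip() if match else None  # no tags: also scored 0
```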
### Debug Mode

Enable debug logging for detailed information:
```python
env_config.debug_logging = True
```

This shows:
- Answer extraction attempts
- Scoring method comparisons
- Format violation details
- Task selection patterns

## Performance Notes

- **Task Selection**: Random selection ensures diverse training
- **Evaluation**: Fixed test set with a deterministic seed for reproducible results
- **Memory Usage**: Buffers are cleared after saving to prevent memory leaks
- **Scoring Efficiency**: Dual-format scoring tries both methods and uses the higher score

## Contributing

When adding new features:

1. Maintain backward compatibility with existing configs
2. Add appropriate logging for debugging
3. Update this README with new configuration options
4. Test with both successful and failed rollout scenarios
diff --git a/environments/reasoning_gym_environment/reasoning-gym b/environments/reasoning_gym_environment/reasoning-gym
new file mode 160000
index 00000000..0699e2f5
--- /dev/null
+++ b/environments/reasoning_gym_environment/reasoning-gym
@@ -0,0 +1 @@
+Subproject commit 0699e2f5074379bb19924b7f64221ccfc8cef6ef
diff --git a/environments/reasoning_gym_environment/reasoning_gym_environment.py b/environments/reasoning_gym_environment/reasoning_gym_environment.py
new file mode 100644
index 00000000..fc1d6638
--- /dev/null
+++ b/environments/reasoning_gym_environment/reasoning_gym_environment.py
@@ -0,0 +1,1304 @@
import json
import logging
import os
import pkgutil
import random
import re
import sys
import uuid
from typing import Any, Dict, List, Optional, Tuple, Union

from pydantic import Field
from tqdm.asyncio import tqdm_asyncio

import wandb
from atroposlib.envs.base import (
    APIServerConfig,
    BaseEnv,
    BaseEnvConfig,
    EvalHandlingEnum,
    Item,
    ScoredDataGroup,
)
from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer

# Add the local reasoning-gym submodule to Python's path
# This allows `import reasoning_gym` to find the local submodule
_SUBMODULE_DIR = os.path.abspath(
    os.path.join(os.path.dirname(__file__), "reasoning-gym")
)
if _SUBMODULE_DIR not in sys.path:
    sys.path.insert(0, _SUBMODULE_DIR)

# Attempt to import reasoning_gym. If not found, a warning will be issued in _get_task_names.
try:
    import reasoning_gym
    from reasoning_gym.utils import extract_answer
except ImportError:
    reasoning_gym = None
    extract_answer = None


system_prompt = (
    "You are a deep thinking AI, you may use extremely long chains of thought to deeply consider the "
    "problem and deliberate with yourself via systematic reasoning processes to help come to a correct "
    "solution prior to answering. You should enclose your thoughts and internal monologue inside "
    "<think></think> tags, and then provide your solution or response to the problem. After your thinking, "
    "make sure to clearly provide your final answer inside <answer></answer> tags. You can provide "
    "context, explanation, etc before and after your answer tags/answer, but you must provide a single "
    "answer and place it inside <answer></answer> tags."
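    # Illustrative shape of a compliant response:
    #   <think> ...extended reasoning... </think>
    #   optional commentary before/after the answer...
    #   <answer>FINAL ANSWER</answer>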
+) + +# Number of evaluation samples to generate per task for the test set +NUM_EVAL_SAMPLES_PER_TASK = 5 +# Seed for generating fixed evaluation set +EVAL_SEED = 123 + + +class ReasoningGymEnvConfig(BaseEnvConfig): + """Extended configuration for ReasoningGymEnv with additional fields.""" + + dump_rollouts: bool = Field( + default=False, + description="Whether to dump successful rollouts (above threshold) to JSONL files.", + ) + dump_failed_rollouts: bool = Field( + default=False, + description="Whether to dump failed rollouts (all 0 scores) to JSONL files for debugging.", + ) + seed: int = Field( + default=42, + description="Random seed for reproducibility.", + ) + debug_logging: bool = Field( + default=False, + description="Enable debug-level logging for more verbose output.", + ) + suppress_base_env_logs: bool = Field( + default=True, + description="Suppress verbose base environment logs (like status dict updates).", + ) + rollout_save_score_threshold: float = Field( + default=0.7, + description="Minimum score threshold for saving rollouts to data dumps. Only groups with at least one rollout above this threshold will be saved.", # noqa: E501 + ) + + def validate_config(self): + """Validate configuration parameters.""" + if not (0.0 <= self.rollout_save_score_threshold <= 1.0): + raise ValueError( + f"rollout_save_score_threshold must be between 0.0 and 1.0, got {self.rollout_save_score_threshold}" + ) + if self.rollout_save_score_threshold == 1.0: + print( + f"Warning: rollout_save_score_threshold is {self.rollout_save_score_threshold}, which may be too strict and result in no saved rollouts." # noqa: E501 + ) + + +class ReasoningGymEnv(BaseEnv): + name = "reasoning_gym" + env_config_cls = ReasoningGymEnvConfig + + def __init__( + self, + config: ReasoningGymEnvConfig, + server_configs: List[APIServerConfig], + slurm=True, + testing=False, + ): + # Validate configuration before proceeding + config.validate_config() + super().__init__(config, server_configs, slurm, testing) + + # Initialize the logger like swe_rl_env.py + self.logger = logging.getLogger(self.__class__.__name__) + if not self.logger.handlers: + # Add a basic stream handler if no handlers are configured + _handler = logging.StreamHandler() + _formatter = logging.Formatter( + "%(asctime)s - %(name)s - %(levelname)s - %(message)s" + ) + _handler.setFormatter(_formatter) + self.logger.addHandler(_handler) + # Set logging level based on config + log_level = logging.DEBUG if self.config.debug_logging else logging.INFO + self.logger.setLevel(log_level) + # Ensure the logger itself is enabled + self.logger.disabled = False + + # Suppress base environment logs if requested + if self.config.suppress_base_env_logs: + # Set the base environment logger to WARNING level to suppress INFO logs + base_logger = logging.getLogger("atroposlib.envs.base") + base_logger.setLevel(logging.WARNING) + + # Set max_token_len for base class compatibility + self.max_token_len = self.config.max_token_length + + self.percent_correct_buffer = list() + self.eval_metrics = list() + self.rollouts_for_wandb: List[List[Tuple[str, float, str, str]]] = [] + self.task_names: List[str] = [] + self.test_items_with_scorers: List[Tuple[Dict[str, Any], Any]] = [] + self.rng = random.Random() + + self.run_uuid = str(uuid.uuid4()) + self.rollouts_to_save_buffer: List[ + Dict[str, Union[str, List[Dict[str, Union[List[Dict[str, str]], float]]]]] + ] = [] + self.processed_item_count = 0 + self.datadumps_dir = os.path.join( + os.path.dirname(os.path.abspath(__file__)), 
"data_dumps" + ) + self.save_file_batch_num = 0 + + # For saving failed rollouts (all 0 scores) for debugging + self.failed_rollouts_to_save_buffer: List[ + Dict[str, Union[str, List[Dict[str, Union[List[Dict[str, str]], float]]]]] + ] = [] + self.failed_processed_item_count = 0 + self.failed_save_file_batch_num = 0 + + @classmethod + def config_init(cls) -> Tuple[ReasoningGymEnvConfig, List[APIServerConfig]]: + env_config = ReasoningGymEnvConfig( + tokenizer_name="NousResearch/DeepHermes-3-Llama-3-8B-Preview", + group_size=16, + use_wandb=True, + rollout_server_url="http://localhost:8000", + total_steps=250, + batch_size=1024, + steps_per_eval=25, + max_token_length=1024 * 16, + inference_weight=1.0, + wandb_name="reasoning_gym_think", # Specific name for reasoning gym + eval_handling=EvalHandlingEnum.LIMIT_TRAIN, + eval_limit_ratio=0.1, + num_rollouts_per_group_for_logging=4, + num_rollouts_to_keep=50, + dump_rollouts=False, + dump_failed_rollouts=False, + debug_logging=False, + suppress_base_env_logs=True, + rollout_save_score_threshold=0.51, + ) + server_configs = [ + APIServerConfig( + model_name="NousResearch/DeepHermes-3-Llama-3-8B-Preview", + base_url="http://localhost:9004/v1", + api_key="x", + num_max_requests_at_once=32, + num_requests_for_eval=256, + ), + ] + return env_config, server_configs + + def _get_task_names(self) -> List[str]: + """ + Gets task names from the reasoning_gym DATASETS registry. + This is more reliable than dynamic discovery via pkgutil. + Falls back to a predefined list if registry access fails. + """ + if reasoning_gym is None: + print( + "ERROR: The local reasoning-gym submodule could not be imported. " + "Ensure it is present in 'atropos/environments/reasoning_gym_environment/reasoning-gym' " + "and is a valid Python package (e.g., has __init__.py)." + ) + return ["algebra/linear_1d", "arithmetic/add_or_subtract", "leg_counting"] + + discovered_tasks = [] + try: + # Access the DATASETS registry directly from reasoning_gym.factory + from reasoning_gym.factory import DATASETS + + discovered_tasks = list(DATASETS.keys()) + print( + f"Discovered {len(discovered_tasks)} tasks from DATASETS registry: {discovered_tasks[:10]}{'...' if len(discovered_tasks) > 10 else ''}" # noqa: E501 + ) + + except Exception as e: + print( + f"WARNING: Could not access DATASETS registry: {e}. " + "Falling back to manual discovery and validation." 
+ ) + + # Fallback to pkgutil discovery if registry access fails + try: + package = reasoning_gym + if not hasattr(package, "__path__"): + raise AttributeError("'reasoning_gym' is not a package") + + base_module_path_parts = package.__name__.split(".") + + print( + f"Attempting to discover tasks in package: {package.__name__} from path: {package.__path__}" + ) + + for _, modname, ispkg in pkgutil.walk_packages( + path=package.__path__, + prefix=package.__name__ + ".", + onerror=lambda name: print( + f"Error importing module during task discovery: {name}" + ), + ): + if not ispkg: + module_parts = modname.split(".") + if ( + len(module_parts) > len(base_module_path_parts) + and module_parts[: len(base_module_path_parts)] + == base_module_path_parts + ): + + task_specific_parts = module_parts[ + len(base_module_path_parts) : + ] + + # Filter out potential private/internal modules + if any( + part.startswith("_") for part in task_specific_parts + ): + continue + + task_name = "/".join(task_specific_parts) + discovered_tasks.append(task_name) + + if discovered_tasks: + print( + f"Dynamically discovered {len(discovered_tasks)} potential task names from submodule." + ) + + except Exception as e2: + print( + f"WARNING: Pkgutil discovery also failed: {e2}. Using fallback list." + ) + + if not discovered_tasks: + print( + "WARNING: All discovery methods failed. Using fallback list of known tasks." + ) + # Complete fallback list with all available reasoning-gym tasks + fallback_tasks = [ + "ab", + "acre", + "advanced_geometry", + "aiw", + "arc_1d", + "arc_agi", + "base_conversion", + "basic_arithmetic", + "bf", + "binary_alternation", + "binary_matrix", + "bitwise_arithmetic", + "boxnet", + "caesar_cipher", + "calendar_arithmetic", + "chain_sum", + "circuit_logic", + "codeio", + "color_cube_rotation", + "complex_arithmetic", + "composite", + "count_bits", + "count_primes", + "countdown", + "course_schedule", + "cryptarithm", + "decimal_arithmetic", + "decimal_chain_sum", + "dice", + "emoji_mystery", + "family_relationships", + "figlet_font", + "fraction_simplification", + "futoshiki", + "game_of_life", + "game_of_life_halting", + "gcd", + "graph_color", + "group_anagrams", + "gsm_symbolic", + "intermediate_integration", + "isomorphic_strings", + "jugs", + "knight_swap", + "knights_knaves", + "largest_island", + "lcm", + "leg_counting", + "letter_counting", + "letter_jumble", + "list_functions", + "mahjong_puzzle", + "manipulate_matrix", + "maze", + "mini_sudoku", + "modulo_grid", + "n_queens", + "needle_haystack", + "number_filtering", + "number_format", + "number_sequence", + "number_sorting", + "palindrome_generation", + "palindrome_partitioning", + "polynomial_equations", + "polynomial_multiplication", + "pool_matrix", + "power_function", + "prime_factorization", + "products", + "propositional_logic", + "puzzle24", + "quantum_lock", + "ransom_note", + "rearc", + "rectangle_count", + "rotate_matrix", + "rotten_oranges", + "rubiks_cube", + "rush_hour", + "self_reference", + "sentence_reordering", + "shortest_path", + "simple_equations", + "simple_geometry", + "simple_integration", + "sokoban", + "spell_backward", + "spiral_matrix", + "string_insertion", + "string_manipulation", + "string_splitting", + "string_synthesis", + "sudoku", + "syllogism", + "time_intervals", + "tower_of_hanoi", + "tsumego", + "word_ladder", + "word_sequence_reversal", + "word_sorting", + "zebra_puzzles", + ] + return self._validate_discovered_tasks( + fallback_tasks, reasoning_gym if reasoning_gym else None + ) + + # 
Validate all discovered tasks + return self._validate_discovered_tasks(discovered_tasks, reasoning_gym) + + def _validate_discovered_tasks( + self, task_names_to_validate: List[str], rg_package_or_none: Any + ) -> List[str]: + """ + Validates a list of task names by attempting to create a dataset for each. + Args: + task_names_to_validate: List of task names (e.g., "domain/task"). + rg_package_or_none: The imported reasoning_gym package, or None if import failed. + Returns: + A list of task names that were successfully validated. + """ + if rg_package_or_none is None: + print( + "Validation SKIPPED: reasoning_gym package not available for validation." + ) + return [] + + valid_tasks = [] + print( + f"Validating {len(task_names_to_validate)} discovered/fallback task names..." + ) + for task_name in task_names_to_validate: + try: + # Use the potentially imported reasoning_gym directly for create_dataset + _ = rg_package_or_none.create_dataset(task_name, size=1, seed=0) + valid_tasks.append(task_name) + except Exception as e: + print( + f"Note: Task '{task_name}' could not be loaded (validation failed): {type(e).__name__} - {e}" + ) # noqa: E501 + pass + + if not valid_tasks and task_names_to_validate: + print( + "WARNING: Validation of discovered/fallback tasks failed for all. This might indicate a systematic issue." # noqa: E501 + ) + # If validation fails for all, it's safer to return an empty list or a minimal known-good set. + # However, if task_names_to_validate was non-empty, returning it raw might be a last resort if validation itself is flawed. # noqa: E501 + # For safety, let's prefer an empty list if validation fails completely. + print("No tasks passed validation.") + return [] + + if not valid_tasks: + print( + "CRITICAL WARNING: No valid reasoning-gym tasks could be loaded. Environment may not function." + ) + # Return an absolute minimal, known-good task if everything else fails. + # This specific task must exist and be loadable in reasoning_gym. + try: + _ = rg_package_or_none.create_dataset("leg_counting", size=1, seed=0) + print("Falling back to absolute minimal task: 'leg_counting'") + return ["leg_counting"] + except Exception: + print("Absolute fallback 'leg_counting' also failed to load.") + return [] + + print(f"Validated {len(valid_tasks)} tasks for use.") + return valid_tasks + + async def setup(self): + # The reasoning_gym import is now handled at the top with sys.path modification. + if reasoning_gym is None: + raise ImportError( + "reasoning-gym library could not be imported from the local submodule. " + "This environment cannot function. Check submodule presence and integrity." + ) + + self.logger.info("Setting up ReasoningGym environment...") + + self.task_names = ( + self._get_task_names() + ) # _get_task_names now uses self._validate_discovered_tasks + if not self.task_names: + raise ValueError( + "No reasoning_gym tasks could be loaded. Environment setup failed." + ) + + self.logger.info( + f"ReasoningGymEnv: Initialized with {len(self.task_names)} tasks." 
+ ) + self.logger.info(f"Sample tasks: {self.task_names[:5]}") + + # Seed for main RNG used in get_next_item + # The seed for reasoning_gym dataset creation will be self.iter + self.rng.seed(self.config.seed) + self.iter = 0 + + # Create a fixed test set for evaluation + self.logger.info("Generating fixed test set for evaluation...") + eval_tasks_sample = self.rng.sample( + self.task_names, min(len(self.task_names), 20) + ) # Sample 20 tasks for eval + + for task_name in tqdm_asyncio(eval_tasks_sample, desc="Creating eval dataset"): + try: + # Each task gets its own dataset instance for evaluation + # Using a fixed seed for reproducibility of the test set + dataset = reasoning_gym.create_dataset( + task_name, size=NUM_EVAL_SAMPLES_PER_TASK, seed=EVAL_SEED + ) + for item in dataset: + self.test_items_with_scorers.append((item, dataset)) + except Exception as e: + self.logger.warning( + f"Could not create eval dataset for task '{task_name}': {e}" + ) + + if not self.test_items_with_scorers: + self.logger.warning( + "No evaluation items could be generated. Evaluation might be skipped or fail." + ) + else: + self.logger.info( + f"Generated {len(self.test_items_with_scorers)} items for the evaluation test set." + ) + + self.logger.info( + "ReasoningGym environment setup complete. Ready to start training!" + ) + self.logger.info( + f"Configuration: group_size={self.config.group_size}, max_token_length={self.config.max_token_length}, steps_per_eval={self.config.steps_per_eval}" # noqa: E501 + ) + if self.config.dump_rollouts: + self.logger.info( + f"Data dumping enabled with score threshold: {self.config.rollout_save_score_threshold}" + ) + self.logger.info( + "Using strict tag enforcement: models must use exactly one tag or receive 0 score" + ) + self.logger.info( + "Using dual-format scoring for valid answers: trying both raw answers and -tagged answers, using higher score" # noqa: E501 + ) + + async def get_next_item(self) -> Optional[Item]: + """ + Get the next training item by randomly selecting a reasoning_gym task and generating a sample. + Returns: + A tuple: (prompt_messages, reasoning_gym_item, reasoning_gym_dataset_object) + prompt_messages: Formatted for the language model. + reasoning_gym_item: The raw item from reasoning_gym (dict with 'question', 'answer', 'metadata'). + reasoning_gym_dataset_object: The dataset object from which the item was generated (for scoring). 
+ """ + if not self.task_names: + return None # Should not happen if setup is correct + + selected_task_name = self.rng.choice(self.task_names) + + try: + # Create a new dataset instance for each item to ensure variety if tasks have internal state + # Use self.iter for seed to get different questions over time + current_seed = self.config.seed + self.iter # Vary seed per item + dataset_obj = reasoning_gym.create_dataset( + selected_task_name, size=1, seed=current_seed + ) + rg_item = next(iter(dataset_obj)) # Get the single item + + # Log task selection every 10 items to avoid spam + if self.iter % 10 == 0: + self.logger.info( + f"Selected task: {selected_task_name} (iteration {self.iter})" + ) + + except Exception as e: + self.logger.warning( + f"Error generating item for task {selected_task_name} with seed {current_seed}: {e}" + ) + return None # Skip this item if generation fails + + self.iter += 1 + + question_text = rg_item["question"] + + # Construct prompt messages + prompt_messages = [ + frozenset({"role": "system", "content": system_prompt}.items()), + frozenset({"role": "user", "content": question_text}.items()), + ] + + # The 'answer' here is the data needed for scoring later + return (tuple(prompt_messages), rg_item, dataset_obj) + + def _extract_final_answer(self, text: str) -> Optional[str]: + """ + Extracts the content from tags using reasoning-gym's extract_answer function. + Enforces strict compliance with answer tag instructions. + Returns None if model doesn't follow instructions properly (no answer tags or multiple answer tags outside think). + """ # noqa: E501 + if extract_answer is None: + # If reasoning-gym not available, we can't enforce the format + if self.config.debug_logging: + self.logger.debug( + "reasoning-gym extract_answer not available, cannot enforce answer tag format" + ) + return None + + # Check for multiple tags outside of blocks + # First, remove all ... content to check only the "outside" content + think_pattern = r".*?" + text_outside_think = re.sub( + think_pattern, "", text, flags=re.DOTALL | re.IGNORECASE + ) + + # Count tags in the content outside blocks + answer_pattern = r".*?" + answer_matches_outside_think = re.findall( + answer_pattern, text_outside_think, flags=re.DOTALL | re.IGNORECASE + ) + + if len(answer_matches_outside_think) > 1: + if self.config.debug_logging: + self.logger.debug( + f"Model provided {len(answer_matches_outside_think)} answer tags outside think blocks - failing for not following single answer instruction" # noqa: E501 + ) + return None + + # Try to extract from tags using reasoning-gym's function + answer_content = extract_answer(text, tag_name="answer", strip=True) + if answer_content is not None: + if self.config.debug_logging: + self.logger.debug( + f"Successfully extracted answer from tags: '{answer_content[:100]}{'...' if len(answer_content) > 100 else ''}'" # noqa: E501 + ) + return answer_content + + # No valid answer tags found - model failed to follow instructions + if self.config.debug_logging: + self.logger.debug( + "No tags found - model failed to follow answer format instructions" + ) + return None + + def _score_answer_with_both_formats( + self, model_answer: str, rg_item: dict, dataset_obj: Any + ) -> float: + """ + Score the answer using both formats (with and without tags) and return the higher score. + This handles verifiers that expect different formats. 
+ """ + # Format 1: Just the answer content + try: + score1 = dataset_obj.score_answer(answer=model_answer, entry=rg_item) + score1 = max(0.0, min(1.0, float(score1))) + except Exception as e: + task_name = rg_item.get("metadata", {}).get("source_dataset", "unknown") + self.logger.debug( + f"Error scoring answer format 1 for task {task_name}: {e}" + ) + score1 = 0.0 + + # Format 2: Answer wrapped in tags + answer_with_tags = f"{model_answer}" + try: + score2 = dataset_obj.score_answer(answer=answer_with_tags, entry=rg_item) + score2 = max(0.0, min(1.0, float(score2))) + except Exception as e: + task_name = rg_item.get("metadata", {}).get("source_dataset", "unknown") + self.logger.debug( + f"Error scoring answer format 2 for task {task_name}: {e}" + ) + score2 = 0.0 + + # Return the higher score + final_score = max(score1, score2) + + # Log which format worked better (only in debug mode) + if self.config.debug_logging and score1 != score2: + task_name = rg_item.get("metadata", {}).get("source_dataset", "unknown") + if score1 > score2: + self.logger.debug( + f"Task {task_name}: Raw answer format scored higher ({score1:.3f} vs {score2:.3f})" + ) + else: + self.logger.debug( + f"Task {task_name}: Tagged answer format scored higher ({score2:.3f} vs {score1:.3f})" + ) + + return final_score + + async def score( + self, + rollout_group_data: List[Tuple[Tuple[Dict[str, str]], Dict[str, Any], Any]], + ) -> Optional[ScoredDataGroup]: + """ + Scores a group of rollouts using reasoning_gym's score_answer method. + Args: + rollout_group_data: A list of tuples, where each tuple contains: + - trajectory_messages: The full conversation history for the rollout. + - rg_item: The original reasoning_gym item (contains 'question', 'answer', 'metadata'). + - dataset_obj: The reasoning_gym dataset object used to generate and score the item. 
+ Returns: + ScoredDataGroup with scores between 0.0 and 1.0, or None if no valid items + """ + scores_container = ScoredDataGroup() + scores_container["tokens"] = list() + scores_container["masks"] = list() + scores_container["scores"] = list() + + if not rollout_group_data: + return None + + rg_item_for_group = rollout_group_data[0][1] + dataset_obj_for_group = rollout_group_data[0][2] + + # Shuffle to avoid bias in selection + random.shuffle(rollout_group_data) + + for trajectory_messages, _, _ in rollout_group_data: + model_full_response = trajectory_messages[-1]["content"] + + # Extract the part of the response that should be the answer + model_answer_to_score = self._extract_final_answer(model_full_response) + + # If extraction failed (model didn't follow instructions), give 0 score + if model_answer_to_score is None: + reward_0_to_1 = 0.0 + if self.config.debug_logging: + task_name = rg_item_for_group.get("metadata", {}).get( + "source_dataset", "unknown" + ) + self.logger.debug( + f"Task {task_name}: Giving 0 score due to failed answer extraction (didn't follow format)" + ) + else: + # Use our dual-format scoring method that tries both raw answer and tagged answer + reward_0_to_1 = self._score_answer_with_both_formats( + model_answer_to_score, rg_item_for_group, dataset_obj_for_group + ) + + # Tokenize the conversation for learning + out_dict = tokenize_for_trainer(self.tokenizer, trajectory_messages) + tokens = out_dict["tokens"] + masks = out_dict["masks"] + + # Remove examples with insufficient context + if len([1 for i in masks if i != -100]) < 10: + continue + + scores_container["tokens"].append(tokens) + scores_container["masks"].append(masks) + scores_container["scores"].append(reward_0_to_1) + + if len(scores_container["tokens"]) >= self.config.group_size: + break + + if not scores_container["tokens"]: + self.logger.warning( + "No valid items were scored in this group - all items had insufficient context or failed scoring" + ) + return None + + # Record success rate metrics (convert to binary for percent_correct tracking) + for score_val in scores_container["scores"]: + self.percent_correct_buffer.append(1.0 if score_val >= 0.5 else 0.0) + + # Calculate and log average score for the current group + current_scores = scores_container.get("scores", []) + if current_scores: + average_score = sum(current_scores) / len(current_scores) + task_name = rg_item_for_group.get("metadata", {}).get( + "source_dataset", "unknown_task" + ) + log_message_main = ( + f"Task: {task_name} | Group average score: {average_score:.4f}" + ) + if all(s >= 0.5 for s in current_scores): + self.logger.info(f"{log_message_main} (All correct in this group!)") + elif all(s == 0.0 for s in current_scores): + self.logger.info(f"{log_message_main} (All failed - no valid answers!)") + elif all(s < 0.5 for s in current_scores): + self.logger.info( + f"{log_message_main} (All incorrect but some partial credit!)" + ) + else: + self.logger.info(log_message_main) + + # Apply length penalty if all responses are correct (score >= 0.5) + if all(s >= 0.5 for s in scores_container["scores"]): + # Calculate token lengths + token_lengths = [len(token) for token in scores_container["tokens"]] + if max(token_lengths) == 0: + return None + + # Get max allowed token length from config + max_allowed_length = self.config.max_token_length + # Set threshold at 75% of max_token_length + length_threshold = max_allowed_length * 0.75 + + # Apply modified length penalty with threshold + new_scores = [] + penalties_applied = 0 + for i, 
length in enumerate(token_lengths): + original_score = scores_container["scores"][i] + if length <= length_threshold: + new_scores.append(original_score) + else: + # Calculate how far we are between threshold and max as a percentage + percentage_of_range = (length - length_threshold) / ( + max_allowed_length - length_threshold + ) + percentage_of_range = min(percentage_of_range, 1.0) + # Apply linear penalty scaling from original_score down to 0.0 + penalized_score = original_score * (1.0 - percentage_of_range) + new_scores.append(penalized_score) + penalties_applied += 1 + + if penalties_applied > 0: + avg_length = sum(token_lengths) / len(token_lengths) + self.logger.debug( + f"Applied length penalty to {penalties_applied}/{len(token_lengths)} responses (avg length: {avg_length:.0f}, threshold: {length_threshold:.0f})" # noqa: E501 + ) + + scores_container["scores"] = new_scores + + # Check if all scores are the same (no learning signal) + if all( + scores_container["scores"][0] == score + for score in scores_container["scores"] + ): + self.logger.debug( + f"All scores in group are identical ({scores_container['scores'][0]:.4f}) - no learning signal, skipping group" # noqa: E501 + ) + + # Before returning None, check if this is a completely failed group (all 0.0 scores) for debugging + if self.config.dump_failed_rollouts and all( + score == 0.0 for score in scores_container["scores"] + ): + self.logger.debug( + "Saving failed group (all 0 scores) for debugging analysis" + ) + await self._save_failed_group_for_debugging( + rollout_group_data, scores_container + ) + + return None + + return scores_container + + async def _save_failed_group_for_debugging( + self, rollout_group_data, scores_container + ): + """Helper method to save failed groups (all 0 scores) for debugging analysis.""" + failed_rollouts_with_scores_to_save = [] + + # Build the failed rollouts data structure + for i, (trajectory_messages, rg_item, dataset_obj) in enumerate( + rollout_group_data + ): + if i < len(scores_container["scores"]): + score_for_rollout = scores_container["scores"][i] + failed_rollouts_with_scores_to_save.append( + { + "conversation": trajectory_messages, # Full conversation history + "score": score_for_rollout, + } + ) + + if failed_rollouts_with_scores_to_save: + # Extract item info for logging - get from first rollout + _, rg_item, _ = rollout_group_data[0] + item_id = rg_item.get("metadata", {}).get("source_dataset", "unknown_task") + + failed_item_data_to_save = { + "item_id": item_id, + "rollouts": failed_rollouts_with_scores_to_save, + } + self.failed_rollouts_to_save_buffer.append(failed_item_data_to_save) + self.failed_processed_item_count += 1 + + # Calculate progress toward next failed save + failed_batch_progress = ( + self.failed_processed_item_count % 50 + ) # Save failed every 50 items + if failed_batch_progress == 0: + failed_batch_progress = ( + 50 # Show 50/50 instead of 0/50 when we hit the threshold + ) + + # Log progress every 10 failed items or when we hit the save threshold + if failed_batch_progress % 10 == 0 or failed_batch_progress == 50: + self.logger.info( + f"Failed rollouts progress: {failed_batch_progress}/50 items buffered " + f"(Total failed processed: {self.failed_processed_item_count}, Failed buffer size: {len(self.failed_rollouts_to_save_buffer)})" # noqa: E501 + ) + + # Check if it's time to save a batch of failed rollouts (every 50 instead of 100) + if ( + self.config.dump_failed_rollouts + and self.failed_processed_item_count % 50 == 0 + and 
self.failed_processed_item_count > 0 + ): + failed_log_msg = ( + f"Reached {self.failed_processed_item_count} failed items. " + f"Triggering save for {len(self.failed_rollouts_to_save_buffer)} failed items " + f"(each with multiple failed rollouts)." + ) + self.logger.info(failed_log_msg) + await self._save_failed_rollouts_to_jsonl() + + async def collect_trajectories( + self, item: Item + ) -> Tuple[Optional[ScoredDataGroup], List]: + """ + Generate model responses for the given item and then score them. + Args: + item: A tuple from get_next_item: (prompt_messages, rg_item, dataset_obj) + """ + prompt_messages_tuple, rg_item, dataset_obj = item + + # Apply chat template to convert messages to a single string + # Ensure prompt_messages_tuple is correctly formatted list of dicts + messages_for_template = [dict(msg_fset) for msg_fset in prompt_messages_tuple] + + prompt_str = self.tokenizer.apply_chat_template( + messages_for_template, add_generation_prompt=True, tokenize=False + ) + + # Calculate max_tokens like tool_calling_server + prompt_tokens = len(self.tokenizer.encode(prompt_str)) + max_tokens = min(1024 * 15, self.config.max_token_length - prompt_tokens) + + completions = await self.server.completion( + prompt=prompt_str, + n=self.config.group_size, + max_tokens=max_tokens, + temperature=0.8, + ) + + to_score_list = [] + for choice in completions.choices: + self.completion_lengths.append(len(choice.text)) + + # Create full trajectory messages for this choice + current_trajectory_messages = list(messages_for_template) + current_trajectory_messages.append( + {"role": "assistant", "content": choice.text} + ) + + to_score_list.append( + (tuple(current_trajectory_messages), rg_item, dataset_obj) + ) + + scored_data_group = await self.score(to_score_list) + + # If rollouts were generated and scored, and data dumping is enabled, prepare them for saving + if scored_data_group and self.config.dump_rollouts: + # Only save groups that have at least one rollout with score > threshold + group_scores = scored_data_group.get("scores", []) + threshold = self.config.rollout_save_score_threshold + if any(score > threshold for score in group_scores): + self.logger.debug( + f"Saving group with scores: {[f'{s:.3f}' for s in group_scores]} (has high-quality rollout, threshold: {threshold})" # noqa: E501 + ) + rollouts_with_scores_to_save = [] + + num_scored_rollouts = len(group_scores) + for i in range(num_scored_rollouts): + conversation_messages = to_score_list[i][0] + score_for_rollout = group_scores[i] + rollouts_with_scores_to_save.append( + { + "conversation": conversation_messages, + "score": score_for_rollout, + } + ) + + if rollouts_with_scores_to_save: + # Extract item info for logging + _, rg_item, _ = item + item_id = rg_item.get("metadata", {}).get( + "source_dataset", "unknown_task" + ) + + item_data_to_save = { + "item_id": item_id, + "rollouts": rollouts_with_scores_to_save, + } + self.rollouts_to_save_buffer.append(item_data_to_save) + self.processed_item_count += 1 + + # Calculate progress toward next save + current_batch_progress = self.processed_item_count % 100 + if current_batch_progress == 0: + current_batch_progress = 100 + + # Log progress every 10 items or when we hit the save threshold + if current_batch_progress % 10 == 0 or current_batch_progress == 100: + self.logger.info( + f"Data dump progress: {current_batch_progress}/100 items buffered " + f"(Total processed: {self.processed_item_count}, Buffer size: {len(self.rollouts_to_save_buffer)})" # noqa: E501 + ) + + # Check if 
it's time to save a batch of rollouts + if ( + self.config.dump_rollouts + and self.processed_item_count % 100 == 0 + and self.processed_item_count > 0 + ): + log_msg = ( + f"Reached {self.processed_item_count} processed items. " + f"Triggering save for {len(self.rollouts_to_save_buffer)} items " + f"(each with multiple scored rollouts)." + ) + self.logger.info(log_msg) + await self._save_rollouts_to_jsonl() + else: + max_score = max(group_scores) if group_scores else 0.0 + self.logger.debug( + f"Skipping group save - no high-quality rollouts (max score: {max_score:.3f}, threshold: {threshold})" # noqa: E501 + ) + + to_backlog = [] + return scored_data_group, to_backlog + + async def _save_rollouts_to_jsonl(self): + """Saves the buffered rollouts to a JSONL file in the datadumps directory.""" + if not self.rollouts_to_save_buffer: + self.logger.info("No rollouts in buffer to save.") + return + + try: + if not os.path.exists(self.datadumps_dir): + os.makedirs(self.datadumps_dir) + self.logger.info(f"Created directory: {self.datadumps_dir}") + except OSError as e: + self.logger.error(f"Error creating directory {self.datadumps_dir}: {e}") + return + + file_path = os.path.join( + self.datadumps_dir, + f"reasoning_gym_environment_rollouts_{self.run_uuid}_{self.save_file_batch_num:04d}.jsonl", + ) + + try: + with open(file_path, "w") as f: + for rollout_dict in self.rollouts_to_save_buffer: + json.dump(rollout_dict, f) + f.write("\n") + self.logger.info( + f"Successfully saved {len(self.rollouts_to_save_buffer)} rollouts to {file_path}" + ) + self.rollouts_to_save_buffer.clear() + self.save_file_batch_num += 1 + except IOError as e: + self.logger.error(f"Error writing rollouts to {file_path}: {e}") + except Exception as e: + self.logger.error( + f"An unexpected error occurred while saving rollouts to {file_path}: {e}" + ) + + async def _save_failed_rollouts_to_jsonl(self): + """Saves the buffered failed rollouts (all 0 scores) to a JSONL file for debugging.""" + if not self.failed_rollouts_to_save_buffer: + self.logger.info("No failed rollouts in buffer to save.") + return + + try: + if not os.path.exists(self.datadumps_dir): + os.makedirs(self.datadumps_dir) + self.logger.info(f"Created directory: {self.datadumps_dir}") + except OSError as e: + self.logger.error(f"Error creating directory {self.datadumps_dir}: {e}") + return + + file_path = os.path.join( + self.datadumps_dir, + f"reasoning_gym_environment_FAILED_rollouts_{self.run_uuid}_{self.failed_save_file_batch_num:04d}.jsonl", + ) + + try: + with open(file_path, "w") as f: + for rollout_dict in self.failed_rollouts_to_save_buffer: + json.dump(rollout_dict, f) + f.write("\n") + self.logger.info( + f"Successfully saved {len(self.failed_rollouts_to_save_buffer)} FAILED rollouts to {file_path}" + ) + self.failed_rollouts_to_save_buffer.clear() + self.failed_save_file_batch_num += 1 + except IOError as e: + self.logger.error(f"Error writing failed rollouts to {file_path}: {e}") + except Exception as e: + self.logger.error( + f"An unexpected error occurred while saving failed rollouts to {file_path}: {e}" + ) + + async def rollout_and_score_eval( + self, test_data_tuple: Tuple[Dict[str, Any], Any] + ) -> float: + """ + Performs a rollout for a single evaluation item and scores it. + Args: + test_data_tuple: A tuple (rg_item, dataset_obj) from self.test_items_with_scorers. + Returns: + Score (1.0 for correct, 0.0 for incorrect/error). 
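            Partial-credit tasks may also return fractional scores in [0.0, 1.0],
            since evaluation reuses the same dual-format scorer as training.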
+ """ + rg_item, dataset_obj = test_data_tuple + question_text = rg_item["question"] + + messages = [ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": question_text}, + ] + + prompt_str = self.tokenizer.apply_chat_template( + messages, add_generation_prompt=True, tokenize=False + ) + + # Calculate max_tokens like tool_calling_server + prompt_tokens = len(self.tokenizer.encode(prompt_str)) + max_tokens = min(1024 * 15, self.config.max_token_length - prompt_tokens) + + completion = await self.server.completion( + prompt=prompt_str, + n=1, + max_tokens=max_tokens, + temperature=0.1, + split="eval", + ) + + model_full_response = completion.choices[0].text + model_answer_to_score = self._extract_final_answer(model_full_response) + + # If extraction failed (model didn't follow instructions), give 0 score + if model_answer_to_score is None: + if self.config.debug_logging: + task_name = rg_item.get("metadata", {}).get("source_dataset", "unknown") + self.logger.debug( + f"Eval - Task {task_name}: Giving 0 score due to failed answer extraction (didn't follow format)" + ) + return 0.0 + + # Use our dual-format scoring method for evaluation as well + return self._score_answer_with_both_formats( + model_answer_to_score, rg_item, dataset_obj + ) + + async def evaluate(self, *args, **kwargs): + self.logger.info("Starting evaluation...") + if not self.test_items_with_scorers: + self.logger.warning("No test items available for evaluation. Skipping.") + self.eval_metrics.append(("eval/percent_correct", 0.0)) + return + + eval_tasks = [ + self.rollout_and_score_eval(item_tuple) + for item_tuple in self.test_items_with_scorers + ] + + self.logger.info( + f"Starting evaluation on {len(self.test_items_with_scorers)} items..." + ) + scores = await tqdm_asyncio.gather(*eval_tasks, desc="Evaluating") + + if not scores: + percent_correct = 0.0 + else: + percent_correct = sum(scores) / len(scores) + + self.eval_metrics.append(("eval/percent_correct", percent_correct)) + self.logger.info(f"Evaluation finished. Percent correct: {percent_correct:.4f}") + + async def add_rollouts_for_wandb( + self, + scored_data: Union[ScoredDataGroup, List[ScoredDataGroup]], + item: Item = None, + ): + if item is None or scored_data is None or not scored_data.get("tokens"): + return + + _, rg_item, _ = item + expected_answer = str(rg_item.get("answer", "N/A")) + task_name = str( + rg_item.get("metadata", {}).get( + "source_dataset", rg_item.get("task_name", "unknown_task") + ) + ) + + # save rollout to trajectory + num_keep = self.config.num_rollouts_per_group_for_logging + if num_keep == -1: + num_keep = self.config.group_size + + # Make sure there's data to log + num_keep = min(num_keep, len(scored_data["tokens"])) + if num_keep == 0: + return + + current_rollouts = [] + for i in range(num_keep): + # Ensure tokens and scores have the same length + if i < len(scored_data["tokens"]) and i < len(scored_data["scores"]): + # Decode the full trajectory including prompt and model response + full_text = self.tokenizer.decode( + scored_data["tokens"][i], skip_special_tokens=True + ) + score_val = scored_data["scores"][i] + current_rollouts.append( + (full_text, score_val, expected_answer, task_name) + ) + else: + print( + f"Warning: Mismatch in lengths of tokens/scores for wandb logging at index {i}." 
+ ) + + self.rollouts_for_wandb.append(current_rollouts) + + if len(self.rollouts_for_wandb) > self.config.num_rollouts_to_keep: + self.rollouts_for_wandb.pop(0) + + async def create_rollout_table(self, wandb_metrics: Dict) -> Dict: + if len(self.rollouts_for_wandb) > 0: + table = wandb.Table( + columns=["full_text", "score", "expected_answer", "task_name"] + ) + for group_rollouts in self.rollouts_for_wandb: + for rollout_tuple in group_rollouts: + # Ensure rollout_tuple has exactly 4 elements as defined in columns + if len(rollout_tuple) == 4: + table.add_data(*rollout_tuple) + else: + print( + f"Warning: Skipping malformed rollout_tuple for wandb table: {rollout_tuple}" + ) + wandb_metrics["train/rollouts"] = table + + self.rollouts_for_wandb = [] + return wandb_metrics + + async def wandb_log(self, wandb_metrics: Optional[Dict] = None): + """ + Log to wandb with comprehensive metrics. + """ + if wandb_metrics is None: + wandb_metrics = dict() + + # Try to calculate percent_correct, skip if there's a division by zero + try: + wandb_metrics["train/percent_correct"] = sum( + self.percent_correct_buffer + ) / len(self.percent_correct_buffer) + except ZeroDivisionError: + pass + + self.percent_correct_buffer = list() + + # Add eval metrics + for item in self.eval_metrics: + wandb_metrics[item[0]] = item[1] + self.eval_metrics = list() + + # Add rollout table + wandb_metrics = await self.create_rollout_table(wandb_metrics) + + # Call superclass wandb_log + await super().wandb_log(wandb_metrics) + + def save_checkpoint(self, step, data=None): + """Save checkpoint including current iteration number, completion lengths, and data dumping state.""" + if data is None: + data = {} + data["iter"] = self.iter + data["rng_state"] = self.rng.getstate() + data["completion_lengths"] = self.completion_lengths + data["processed_item_count"] = self.processed_item_count + data["save_file_batch_num"] = self.save_file_batch_num + data["failed_processed_item_count"] = self.failed_processed_item_count + data["failed_save_file_batch_num"] = self.failed_save_file_batch_num + super().save_checkpoint(step, data) + + def load_checkpoint(self): + """Load checkpoint including iteration number, completion lengths, and data dumping state.""" + # Call the base class method first to load the data + super().load_checkpoint() + + # The base class loads data into attributes, so we can access them directly + # if they were saved in save_checkpoint + if hasattr(self, "iter"): + # Data was loaded successfully, no need to do anything else + pass + + async def close(self): + """Clean up and save any remaining rollouts before exiting.""" + self.logger.info( + "Closing ReasoningGymEnv. Attempting to save any remaining rollouts..." + ) + if ( + self.config.dump_rollouts and self.rollouts_to_save_buffer + ): # Check if there's anything to save + self.logger.info( + f"Found {len(self.rollouts_to_save_buffer)} rollouts in buffer. Saving now." + ) + await self._save_rollouts_to_jsonl() + else: + self.logger.info("No rollouts in buffer to save upon closing.") + + # Also save any remaining failed rollouts + if self.config.dump_failed_rollouts and self.failed_rollouts_to_save_buffer: + self.logger.info( + f"Found {len(self.failed_rollouts_to_save_buffer)} failed rollouts in buffer. Saving now." 
+ ) + await self._save_failed_rollouts_to_jsonl() + else: + self.logger.info("No failed rollouts in buffer to save upon closing.") + + # Call the superclass's close method if it exists + if hasattr(super(), "close"): + await super().close() + self.logger.info("ReasoningGymEnv closed.") + + +if __name__ == "__main__": + ReasoningGymEnv.cli() diff --git a/environments/reasoning_gym_environment/requirements.txt b/environments/reasoning_gym_environment/requirements.txt new file mode 100644 index 00000000..6436dec9 --- /dev/null +++ b/environments/reasoning_gym_environment/requirements.txt @@ -0,0 +1,13 @@ +numpy>=1.21.0 +sympy>=1.9.0 +matplotlib>=3.5.0 +pillow>=8.3.0 + +tqdm>=4.62.0 +asyncio-throttle>=1.0.0 + +jsonlines>=2.0.0 + +colorlog>=6.0.0 + +# Python version requirement (reasoning-gym requires Python >= 3.10)