diff --git a/environments/infinimath/infinimath_env.py b/environments/infinimath/infinimath_env.py
index 5e89a56c..758e015f 100644
--- a/environments/infinimath/infinimath_env.py
+++ b/environments/infinimath/infinimath_env.py
@@ -45,25 +45,21 @@ Remember to format your final answer correctly as this is important for evaluati
 class InfiniteMathEnvConfig(BaseEnvConfig):
     """Configuration for the InfiniteMath environment."""

-    # Curriculum parameters
     starting_level: int = 1
     progress_threshold: float = 0.8
     min_evaluations: int = 5

-    # Environment parameters
     max_attempts_per_problem: int = 3
     correct_reward: float = 1.0
     incorrect_reward: float = -1.0
-    think_block_bonus: float = 0.2  # Bonus for a well-formed think block
-    boxed_answer_bonus: float = 0.2  # Bonus for a well-formed boxed answer
+    think_block_bonus: float = 0.2
+    boxed_answer_bonus: float = 0.2

-    # Length penalty parameters
     apply_length_penalty: bool = True
     length_threshold_ratio: float = (
-        0.5  # Percentage of max_token_length before penalties apply
+        0.5
     )

-    # Completion parameters
     temperature: float = 0.7
     top_p: float = 0.9

@@ -79,40 +75,34 @@ class InfiniteMathEnv(BaseEnv):
         testing=False,
     ):
         super().__init__(config, server_configs, slurm, testing)
-        self.config = config  # Override with our specific config class
+        self.config = config

-        # Initialize tracking metrics
         self.percent_correct_buffer = []
         self.level_correct_buffer = {
             i: [] for i in range(1, 8)
-        }  # Track correctness for each level
+        }
         self.eval_metrics = []

-        # Curriculum will be initialized in setup()
         self.curriculum = None

-        # Set the system prompt
         self.system_prompt = system_prompt

     async def setup(self):
         """Initialize the environment and curriculum."""
         logger.info("Setting up InfiniteMathEnv")

-        # Initialize curriculum
         self.curriculum = MathCurriculum(
             starting_level=self.config.starting_level,
             progress_threshold=self.config.progress_threshold,
             min_evaluations=self.config.min_evaluations,
         )

-        # Generate some test problems for each level for evaluation
         self.eval_problems = {}
         for level in range(1, 8):
             self.eval_problems[level] = []
             temp_curriculum = MathCurriculum(starting_level=level)

-            # Generate 10 test problems for each level
             attempts = 0
-            max_attempts_per_level = 20  # Try at most 20 problems to get 10 valid ones
+            max_attempts_per_level = 20

             while (
                 len(self.eval_problems[level]) < 10
@@ -120,9 +110,8 @@ class InfiniteMathEnv(BaseEnv):
             ):
                 try:
                     problem, solution, generator_id = temp_curriculum.get_problem()
-                    # Strip LaTeX delimiters
-                    problem = self.strip_latex_delimiters(problem)
-                    solution = self.strip_latex_delimiters(solution)
+                    problem = self._strip_latex_delimiters(problem)
+                    solution = self._strip_latex_delimiters(solution)
                     self.eval_problems[level].append((problem, solution, generator_id))
                 except Exception as e:
                     logger.warning(
@@ -134,7 +123,6 @@ class InfiniteMathEnv(BaseEnv):
                 f"Generated {len(self.eval_problems[level])} evaluation problems for level {level}"
             )

-        # If any levels have no problems, add a simple fallback
         for level in range(1, 8):
             if not self.eval_problems[level]:
                 logger.warning(
@@ -169,9 +157,8 @@ class InfiniteMathEnv(BaseEnv):
                 else:
                     self.eval_problems[level].append(("What is |3 - 10|?", "7", 71))

-    def strip_latex_delimiters(self, text: str) -> str:
+    def _strip_latex_delimiters(self, text: str) -> str:
         """Strip LaTeX delimiters ($...$) from text."""
-        # Handle both inline expressions $...$ and expressions that make up the entire string
         return re.sub(r"\$(.*?)\$", r"\1", text)

     def save_checkpoint(self, step, data=None):
@@ -179,7 +166,6 @@ class InfiniteMathEnv(BaseEnv):
         if data is None:
             data = {}
-        # Save curriculum state
         data["curriculum_level"] = self.curriculum.get_current_level()
         data["performance_history"] = {
             str(k): v for k, v in self.curriculum.performance_history.items()
         }
@@ -191,19 +177,16 @@ class InfiniteMathEnv(BaseEnv):
         """Load curriculum state from checkpoint."""
         super().load_checkpoint()

-        # Check if we have curriculum data in the checkpoint
         checkpoint_path = f"{self.checkpoint_dir}/env_checkpoints/{self.wandb_prepend}/step-{self.curr_step}.json"
         try:
             with open(checkpoint_path, "r") as f:
                 data = json.load(f)

-            # Restore curriculum state if available
             if "curriculum_level" in data:
                 level = data["curriculum_level"]
                 self.curriculum.current_level = level

             if "performance_history" in data:
-                # Convert string keys back to integers
                 self.curriculum.performance_history = {
                     int(k): v for k, v in data["performance_history"].items()
                 }
@@ -215,7 +198,6 @@ class InfiniteMathEnv(BaseEnv):
         if wandb_metrics is None:
             wandb_metrics = {}

-        # Log overall correct percentage
         try:
             wandb_metrics["train/percent_correct"] = sum(
                 self.percent_correct_buffer
@@ -223,7 +205,6 @@ class InfiniteMathEnv(BaseEnv):
         except ZeroDivisionError:
             pass

-        # Log per-level metrics
         for level, buffer in self.level_correct_buffer.items():
             if buffer:
                 wandb_metrics[f"train/level_{level}_correct"] = sum(buffer) / len(
@@ -231,7 +212,6 @@ class InfiniteMathEnv(BaseEnv):
                 )
                 wandb_metrics[f"train/level_{level}_count"] = len(buffer)

-        # Log current level and curriculum progress
         if self.curriculum:
             current_level = self.curriculum.get_current_level()
             max_level = max(self.curriculum.DIFFICULTY_LEVELS.keys())
@@ -242,12 +222,10 @@ class InfiniteMathEnv(BaseEnv):
                 current_level / max_level
             ) * 100

-            # Log level description
             wandb_metrics["curriculum/level_description"] = (
                 self.curriculum.get_level_description()
             )

-            # Log performance history for current level
             if current_level in self.curriculum.performance_history:
                 history = self.curriculum.performance_history[current_level]
                 if history:
@@ -266,49 +244,25 @@ class InfiniteMathEnv(BaseEnv):
                         0, self.curriculum.progress_threshold - success_rate
                     )

-        # Log reward function metrics
-        # REMOVED: Specific reward function config logging as it's not used anymore
-        # if hasattr(self, "reward_function") and self.wandb:
-        #     if hasattr(self.reward_function, "set_wandb_logger"):
-        #         self.reward_function.set_wandb_logger(self.wandb)
-
-        #     # Log the reward configurations
-        #     if isinstance(self.config.reward_functions, list) and self.config.reward_functions:
-        #         # Log the reward configuration
-        #         wandb_metrics["reward/format_reward_enabled"] = "format" in self.config.reward_functions
-        #         wandb_metrics["reward/boxed_reward_enabled"] = "boxed" in self.config.reward_functions
-
-        #         if hasattr(self.config, "format_reward_weight"):
-        #             wandb_metrics["reward/format_reward_weight"] = self.config.format_reward_weight
-
-        #         if hasattr(self.config, "boxed_reward_weight"):
-        #             wandb_metrics["reward/boxed_reward_weight"] = self.config.boxed_reward_weight
-
-        # Add eval metrics
         for item in self.eval_metrics:
             wandb_metrics[item[0]] = item[1]

-        # Reset buffers
         self.percent_correct_buffer = []
         for level in self.level_correct_buffer:
             self.level_correct_buffer[level] = []
         self.eval_metrics = []

-        # Call the parent method to handle remaining metrics
         await super().wandb_log(wandb_metrics)

     async def get_next_item(self):
         """Get the next problem based on current curriculum level."""
         problem, solution, generator_id = self.curriculum.get_problem()

-        # Strip LaTeX delimiters from problem and solution
-        problem = self.strip_latex_delimiters(problem)
-        solution = self.strip_latex_delimiters(solution)
+        problem = self._strip_latex_delimiters(problem)
+        solution = self._strip_latex_delimiters(solution)

-        # Create a message with the problem
         prompt = tuple([frozenset({"role": "user", "content": problem}.items())])

-        # Return the problem with metadata
         return (prompt, solution, generator_id)

     async def evaluate(self, *args, **kwargs):
@@ -316,7 +270,6 @@ class InfiniteMathEnv(BaseEnv):
         current_level = self.curriculum.get_current_level()
         logger.info(f"Starting evaluation for curriculum level {current_level}")

-        # Only evaluate problems at the current level
         eval_tasks = []
         eval_generator_ids = []
         if current_level in self.eval_problems:
@@ -332,11 +285,9 @@ class InfiniteMathEnv(BaseEnv):
             )
             return []

-        # Run evaluation tasks
         logger.info(f"Evaluating {len(eval_tasks)} problems at level {current_level}")
         results = await asyncio.gather(*eval_tasks)

-        # Calculate accuracy for the current level
         correct_count = sum(1 for _, is_correct in results if is_correct)
         total_count = len(results)
         accuracy = correct_count / total_count if total_count > 0 else 0
@@ -345,23 +296,18 @@ class InfiniteMathEnv(BaseEnv):
             f"Level {current_level} accuracy: {accuracy:.2f} ({correct_count}/{total_count})"
         )

-        # Record metrics for the current level
         self.eval_metrics.append((f"eval/level_{current_level}_accuracy", accuracy))
         self.eval_metrics.append(("eval/current_level", current_level))

-        # Record the actual evaluation results in the curriculum's performance history
         for i, (_, is_correct) in enumerate(results):
             if i < len(eval_generator_ids):
-                # Record the actual result
                 self.curriculum.record_performance(eval_generator_ids[i], is_correct)
             else:
-                # Fallback if somehow the lists are different lengths
                 sample_generator_id = random.choice(
                     self.curriculum.DIFFICULTY_LEVELS[current_level]
                 )
                 self.curriculum.record_performance(sample_generator_id, is_correct)

-        # Try to advance to the next level
         advanced = self.curriculum.advance_difficulty()
         new_level = self.curriculum.get_current_level()
@@ -381,74 +327,62 @@ class InfiniteMathEnv(BaseEnv):
         try:
             logger.debug(f"Evaluating level {level} problem: {problem[:30]}...")

-            # Convert messages to a single prompt using the tokenizer
             messages = [
                 {"role": "system", "content": self.system_prompt},
                 {"role": "user", "content": problem},
             ]
             prompt = self.tokenizer.apply_chat_template(messages, tokenize=False)

-            # Add prefilled thinking starter
             prefill = "<think>\n\n"
             prefilled_prompt = prompt + prefill

-            # Generate completion using the prompt
             logger.debug(f"Requesting completion for problem: {problem[:30]}...")
             completion = await self.server.completion(
                 prompt=prefilled_prompt,
                 n=1,
                 max_tokens=self.config.max_token_length,
-                temperature=0.0,  # Use 0 temperature for deterministic results
+                temperature=0.0,
                 top_p=1.0,
                 split="eval",
             )

-            # Extract the completion text and prepend the thinking starter
             model_answer = prefill + (
                 completion.choices[0].text
                 if hasattr(completion.choices[0], "text")
                 else completion.choices[0].message.content
             )

-            # Check if the answer is correct
             is_correct = self.check_answer(model_answer, solution)

             logger.debug(f"Problem evaluated: level={level}, correct={is_correct}")
             return level, is_correct

         except Exception as e:
             logger.error(f"Error evaluating problem: {e}")
-            # Return a failed result in case of error
             return level, False

     def check_answer(self, model_answer: str, solution: str) -> bool:
solution.""" - # Extract the part after the thinking block after_think_part = ( model_answer.split("")[-1].strip() if "" in model_answer else model_answer ) - # Extract the boxed answer if present boxed_answer = self._extract_boxed_answer(after_think_part) if not boxed_answer: - # Try to find the answer in the last line lines = after_think_part.strip().split("\n") if lines: boxed_answer = lines[-1].strip() - # Clean up answers for comparison (remove spaces, convert to lowercase) model_clean = self._clean_for_comparison( boxed_answer if boxed_answer else after_think_part ) solution_clean = self._clean_for_comparison(solution) - # Check if they match return model_clean == solution_clean def _extract_boxed_answer(self, text: str) -> Optional[str]: """Extract answer from a LaTeX boxed expression.""" - # Try to find boxed content boxed_match = re.search(r"\\boxed{([^}]*)}", text) if boxed_match: return boxed_match.group(1) @@ -456,7 +390,6 @@ class InfiniteMathEnv(BaseEnv): def _clean_for_comparison(self, text: str) -> str: """Clean text for comparison.""" - # Remove LaTeX commands, spaces, commas, and convert to lowercase cleaned = re.sub(r"\\[a-zA-Z]+", "", text) cleaned = re.sub(r"[,\s]", "", cleaned) cleaned = cleaned.lower() @@ -464,11 +397,8 @@ class InfiniteMathEnv(BaseEnv): async def collect_trajectories(self, item) -> Tuple[List, List]: """Collect trajectories for the current item.""" - # Extract information from the item problem_prompt, solution, generator_id = item - # Create prompt using tokenizer's chat template - # Add prefilled thinking starter prefill = "\n\n" messages = [ {"role": "system", "content": self.system_prompt}, @@ -477,7 +407,6 @@ class InfiniteMathEnv(BaseEnv): ] prompt = self.tokenizer.apply_chat_template(messages, tokenize=False) - # Generate completions using completion API completions = await self.server.completion( prompt=prompt, n=self.config.group_size, @@ -486,19 +415,15 @@ class InfiniteMathEnv(BaseEnv): top_p=self.config.top_p, ) - # Prepare data for scoring to_score = [] - # Track level for metrics level = None for lvl, generator_ids in self.curriculum.DIFFICULTY_LEVELS.items(): if generator_id in generator_ids: level = lvl break - # Process each completion for i, completion in enumerate(completions.choices): - # Get the completion text and prepend the thinking starter model_answer = prefill + ( completion.text if hasattr(completion, "text") @@ -506,20 +431,14 @@ class InfiniteMathEnv(BaseEnv): ) print("model_answer", model_answer) - # Build complete message sequence full_messages = [ {"role": "system", "content": self.system_prompt}, {"role": "user", "content": dict(problem_prompt[0])["content"]}, {"role": "assistant", "content": model_answer}, ] - # Add to scoring list to_score.append((full_messages, solution, generator_id, level)) - - # Record performance in curriculum for each item we're scoring - # This will be called again after scoring, but that's fine - # No additional items for backlog backlog = [] return to_score, backlog @@ -532,57 +451,45 @@ class InfiniteMathEnv(BaseEnv): scored_data["scores"] = [] scored_data["messages"] = [] - # Process each item in the rollout data for i, (messages, solution, generator_id, level) in enumerate(rollout_group_data): model_answer = messages[-1]["content"] current_score = 0.0 - # 1. 

-            # 1. Accuracy Check
             is_correct = self.check_answer(model_answer, solution)
             if is_correct:
                 current_score += self.config.correct_reward
             else:
                 current_score += self.config.incorrect_reward

-            # Record answer correctness for tracking and curriculum
             self.percent_correct_buffer.append(1 if is_correct else 0)
             if level is not None:
                 self.level_correct_buffer[level].append(1 if is_correct else 0)
             self.curriculum.record_performance(generator_id, is_correct)

-            # 2. Thinking Block Check
             think_match = re.search(r"<think>(.*?)</think>", model_answer, re.DOTALL)
             if think_match:
                 think_content = think_match.group(1).strip()
-                if think_content:  # Check if there's actual content
+                if think_content:
                     current_score += self.config.think_block_bonus
-                # else: penalty for empty think block, or neutral
-            # else: penalty for missing think block, or neutral

-            # 3. Boxed Answer Check
-            # Extract the part after the thinking block for boxed answer validation
             after_think_part = model_answer.split("</think>")[-1].strip() if "</think>" in model_answer else model_answer
             boxed_answer_content = self._extract_boxed_answer(after_think_part)
-            if boxed_answer_content is not None:  # Check if \boxed{} is present and has content
+            if boxed_answer_content is not None:
                 current_score += self.config.boxed_answer_bonus
-            # else: penalty for missing/malformed boxed answer, or neutral

             logger.info(f"Item {i}: Correct: {is_correct}, Think Bonus: {self.config.think_block_bonus if think_match and think_match.group(1).strip() else 0}, Boxed Bonus: {self.config.boxed_answer_bonus if boxed_answer_content is not None else 0}, Final Score: {current_score}")

-            # Tokenize for the trainer
             tokens_dict = tokenize_for_trainer(
                 self.tokenizer,
-                messages,  # These are the full messages including system, user, assistant
-                None,  # Not used by this tokenizer function apparently
+                messages,
+                None,
             )

-            # Add to scored data
             scored_data["tokens"].append(tokens_dict["tokens"])
             scored_data["masks"].append(tokens_dict["masks"])
             scored_data["scores"].append(current_score)
             scored_data["messages"].append(messages)

-        # Advance difficulty if criteria met
         self.curriculum.advance_difficulty()

         return scored_data
@@ -605,7 +512,6 @@ if __name__ == "__main__":
         inference_weight=1.0,
         wandb_name="infinite_math",
         data_path_to_save_groups="data/infinite_math_groups.jsonl",
-        # InfiniteMath specific config
         starting_level=1,
         progress_threshold=0.8,
         min_evaluations=10,
@@ -613,10 +519,8 @@ if __name__ == "__main__":
         incorrect_reward=-0.5,
         apply_length_penalty=True,
         length_threshold_ratio=0.6,
-        # Completion parameters
         temperature=0.7,
         top_p=0.9,
-        # Reward function configuration - use name directly
         reward_functions=["accuracy", "format", "boxed"],
         accuracy_reward_weight=1.0,
         format_reward_weight=0.2,