import json import os import random import re import uuid from typing import Dict, List, Optional, Tuple import wandb from datasets import Dataset, load_dataset from langdetect import LangDetectException, detect from pydantic import Field from tqdm.asyncio import tqdm_asyncio from atroposlib.envs.base import ( APIServerConfig, BaseEnv, BaseEnvConfig, EvalHandlingEnum, Item, ScoredDataGroup, ) # System prompt can be reused or adapted for instruction following tasks system_prompt = ( "You are a deep thinking AI, you may use extremely long chains of thought to deeply consider the " "problem and deliberate with yourself via systematic reasoning processes to help come to a correct " "solution prior to answering. You should enclose your thoughts and internal monologue inside " " tags, and then provide your solution or response to the problem." ) class IFConfig(BaseEnvConfig): dataset_name: str = Field("allenai/RLVR-IFeval", description="Default dataset name") dataset_config_name: Optional[str] = Field( None, description="Dataset config name, if any" ) test_set_ratio: float = Field( 0.05, description="The ratio of the selected dataset for testing" ) dump_rollouts: bool = Field( False, description="Whether to dump successful rollouts to JSONL files" ) dump_failed_rollouts: bool = Field( False, description="Whether to dump failed rollouts (all 0 scores) to JSONL files for debugging", ) rollout_save_score_threshold: float = Field( 0.7, description="Minimum score threshold for saving rollouts to data dumps" ) max_group_average_for_training: float = Field( 0.75, description="Maximum group average score to use for training (skip groups that are too easy)", ) dataset_shuffle_seed: int = Field( 42, description="Seed for shuffling the dataset during setup" ) resume_from_unsolved_dataset: Optional[str] = Field( None, description="Path to a remaining_unsolved.jsonl file to resume training from specific unsolved items", ) suppress_base_env_logs: bool = Field( default=True, description="Suppress verbose base environment logs (like status dict updates).", ) solve_on_single_correct: bool = Field( default=False, description="Mark item as solved if even one rollout in the group gets it correct (removes from circulation)", ) class InstructionFollowingEnv(BaseEnv): env_config_cls = IFConfig def __init__( self, config: IFConfig, server_configs: List[APIServerConfig], slurm=True, testing=False, ): super().__init__(config, server_configs, slurm, testing) self.percent_correct_buffer = list() self.eval_metrics = list() self.rollouts_for_wandb = [] # Data dumping infrastructure self.rollouts_to_save_buffer = [] self.failed_rollouts_to_save_buffer = [] self.run_uuid = str(uuid.uuid4())[:8] self.save_file_batch_num = 0 self.failed_save_file_batch_num = 0 # Adaptive curriculum: cycling queue for unsolved items self.active_train_queue = [] # Items currently in circulation self.solved_items = [] # Items that have been solved (removed from circulation) self.item_attempt_counts = ( {} ) # Track how many times each item has been attempted # Create data dumps directory self.datadumps_dir = os.path.join( os.path.dirname(os.path.abspath(__file__)), "data_dumps" ) # Create datasets directory for curriculum state dumps self.datasets_dir = os.path.join( os.path.dirname(os.path.abspath(__file__)), "datasets" ) # Validate configuration for potential conflicts self._validate_config() # Configure logging suppression if self.config.suppress_base_env_logs: import logging # Suppress specific loggers that are too verbose logging.getLogger("atroposlib.envs.base").setLevel(logging.WARNING) logging.getLogger("httpx").setLevel(logging.WARNING) logging.getLogger("httpcore").setLevel(logging.WARNING) @classmethod def config_init( self, ) -> Tuple[IFConfig, List[APIServerConfig]]: # Configuration for the Instruction Following Environment env_config = IFConfig( tokenizer_name="NousResearch/DeepHermes-3-Llama-3-8B-Preview", group_size=32, use_wandb=True, rollout_server_url="http://localhost:8000", total_steps=500, batch_size=1024, steps_per_eval=20, max_token_length=1024 * 15, inference_weight=1.0, wandb_name="instruction_following_rlvr_ifeval", # Specific WandB project name eval_handling=EvalHandlingEnum.LIMIT_TRAIN, eval_limit_ratio=0.1, dataset_name="allenai/RLVR-IFeval", # Default dataset dataset_config_name=None, # RLVR-IFeval doesn't have a specific config name, uses 'default' test_set_ratio=0.05, # The ratio of the selelcted dataset in % dump_rollouts=False, # Enable data dumping if needed dump_failed_rollouts=False, # Enable failed rollout dumping for debugging rollout_save_score_threshold=0.7, # Save rollouts with score >= 0.7 max_group_average_for_training=0.75, # Skip groups that are too easy for training dataset_shuffle_seed=42, # Seed for dataset shuffling resume_from_unsolved_dataset=None, # Path to resume from unsolved items suppress_base_env_logs=True, # Suppress verbose base environment logs solve_on_single_correct=False, # Mark item as solved if any rollout gets it correct ) # Server configurations can be similar to SingleToolCallingEnv or adjusted server_configs = [ APIServerConfig( model_name="NousResearch/DeepHermes-3-Llama-3-8B-Preview", base_url="http://localhost:9004/v1", api_key="x", num_max_requests_at_once=32, num_requests_for_eval=256, ) ] return env_config, server_configs async def create_rollout_table(self, wandb_metrics): # Logs rollouts to a WandB table for visualization if len(self.rollouts_for_wandb) > 0: table = wandb.Table(columns=["text", "score", "constraint_details"]) for group in self.rollouts_for_wandb: for item in group: # item[0] is model output, item[1] is score, item[2] is constraint info table.add_data(item[0], item[1], json.dumps(item[2])) wandb_metrics["train/rollouts"] = table self.rollouts_for_wandb = [] return wandb_metrics async def wandb_log(self, wandb_metrics: Optional[Dict] = None): # Logs metrics to WandB if wandb_metrics is None: wandb_metrics = dict() try: wandb_metrics["train/percent_correct"] = sum( self.percent_correct_buffer ) / len(self.percent_correct_buffer) except ZeroDivisionError: pass # Buffer might be empty # Add adaptive curriculum metrics total_items = len(self.active_train_queue) + len(self.solved_items) if total_items > 0: wandb_metrics["curriculum/active_items"] = len(self.active_train_queue) wandb_metrics["curriculum/solved_items"] = len(self.solved_items) wandb_metrics["curriculum/percent_solved"] = ( len(self.solved_items) / total_items ) wandb_metrics["curriculum/total_items"] = total_items # Average attempt count for items still in circulation if self.item_attempt_counts: active_attempts = [ count for item_id, count in self.item_attempt_counts.items() if any( f"{item['func_name']}_{hash(str(item)) % 100000}" == item_id for item in self.active_train_queue ) ] if active_attempts: wandb_metrics["curriculum/avg_attempts_active"] = sum( active_attempts ) / len(active_attempts) self.percent_correct_buffer = list() for item in self.eval_metrics: wandb_metrics[item[0]] = item[1] self.eval_metrics = list() await super().wandb_log(wandb_metrics) async def setup(self): """ Load and preprocess the dataset for instruction following. This method is specifically tailored to process 'allenai/RLVR-IFeval' dataset structure. Each item from RLVR-IFeval is expected to have: - 'messages': A list of dictionaries, e.g., [{'role': 'user', 'content': 'instruction...'}] - 'ground_truth': A JSON string containing 'func_name' and arguments for the verifier. The method will parse these to produce items for the environment with: - 'prompt': The user's instruction string. - 'func_name': The string name of the verifier function. - 'args': A dictionary of arguments for that verifier function. """ # noqa: E501 dataset_name = getattr(self.config, "dataset_name", "allenai/RLVR-IFeval") dataset_config_name = getattr( self.config, "dataset_config_name", None ) # Default is None, RLVR-IFeval has no sub-config processed_items = [] try: print( f"Attempting to load dataset: {dataset_name}, " f"config: {dataset_config_name if dataset_config_name else 'default'}" ) if dataset_config_name: full_dataset_raw = load_dataset( dataset_name, dataset_config_name, split="train", trust_remote_code=True, ) else: full_dataset_raw = load_dataset( dataset_name, split="train", trust_remote_code=True ) print( f"Successfully loaded raw dataset. Number of items: {len(full_dataset_raw)}" ) for i, item in enumerate(full_dataset_raw): # Extract prompt from 'messages' field item_messages = item.get("messages") if ( not item_messages or not isinstance(item_messages, list) or len(item_messages) == 0 ): print( f"Warning: Item {i} has invalid or empty 'messages' field. Skipping. Item: {item}" ) continue # Assuming the relevant prompt is the content of the first message in the list # (or last, if multiple user messages were possible, but IFEval is typically single user instruction) prompt_text = item_messages[0].get("content") if not prompt_text: print( f"Warning: Item {i} '{item_messages[0]}' has no content. Skipping." ) continue # Get the ground_truth JSON string ground_truth_json_str = item.get("ground_truth") if not ground_truth_json_str or not isinstance( ground_truth_json_str, str ): print( f"Warning: Item {i} missing or has invalid 'ground_truth' string. Skipping. " f"Prompt: {prompt_text[:50]}..." ) continue try: parsed_gt = json.loads(ground_truth_json_str) if not isinstance(parsed_gt, dict): raise ValueError("Parsed ground_truth is not a dictionary.") except (json.JSONDecodeError, ValueError) as e: print( f"Warning: Could not parse 'ground_truth' JSON for item {i}. Error: {e}. " f"GT String: '{ground_truth_json_str}'. Prompt: {prompt_text[:50]}... Skipping." ) continue func_name_from_gt = parsed_gt.get("func_name") if not func_name_from_gt: print( f"Warning: Item {i} parsed 'ground_truth' has no 'func_name'. GT: {parsed_gt}. " f"Prompt: {prompt_text[:50]}... Skipping." ) continue if func_name_from_gt not in IF_FUNCTIONS_MAP: print( f"Warning: func_name '{func_name_from_gt}' in item {i} not in IF_FUNCTIONS_MAP. " f"Prompt: {prompt_text[:50]}... Skipping." ) continue # Prepare args for the verifier function: remove func_name and keep others. # Verifier functions will only use args they expect. args_dict = { k: v for k, v in parsed_gt.items() if k != "func_name" and v is not None } processed_items.append( { "prompt": prompt_text, "func_name": func_name_from_gt, "args": args_dict, "original_constraints_for_logging": str( item.get("constraint", "") ), # For logging, from RLVR-IFeval structure "expected_response_for_logging": "", } ) if not processed_items: print( "Warning: No items successfully processed from the dataset. " "Check dataset format/content or parsing logic." ) raise ValueError( "Dataset processing resulted in no valid items for RLVR-IFeval. Cannot proceed without data." ) full_dataset = Dataset.from_list(processed_items) print( f"Successfully processed {len(full_dataset)} items from dataset '{dataset_name}'." ) except Exception as e: # This block is a fallback if the primary dataset loading/processing fails catastrophically. # For RLVR-IFeval, a failure here suggests issues with Hugging Face access, # dataset integrity, or fundamental code errors. print( f"CRITICAL: Failed to load or process primary dataset '{dataset_name}': {e}. " f"Using a DUMMY dataset as fallback." ) dummy_data_for_fallback = [ { "prompt": "Dummy Instruction 1: Ensure your response contains the word 'example'.", "func_name": "verify_keywords", "args": {"keyword_list": ["example"]}, "original_constraints_for_logging": "Contains 'example'", "expected_response_for_logging": "This is an example response.", }, { "prompt": "Dummy Instruction 2: Output a valid JSON with key 'data' and value 'test'.", "func_name": "validate_json_format", "args": {}, "original_constraints_for_logging": "Output valid JSON.", "expected_response_for_logging": '{\\"data\\": \\"test\\"}', }, ] full_dataset = Dataset.from_list(dummy_data_for_fallback) print( f"Initialized with DUMMY dataset of {len(full_dataset)} items " f"due to previous errors." ) full_dataset = full_dataset.shuffle(seed=self.config.dataset_shuffle_seed) actual_test_size = self.config.test_set_ratio # Read from config num_items = len(full_dataset) if num_items == 0: print("ERROR: Dataset is empty. Cannot create train/test split.") self.train = Dataset.from_list([]) self.test = Dataset.from_list([]) elif num_items == 1: print("Warning: Dataset has only 1 item. Using it for both train and test.") self.train = full_dataset self.test = full_dataset else: # num_items > 1 # Ensure test_size results in at least 1 item for test set if possible, but not more than train set if num_items < 5: # For 2,3,4 items, make test size 1 min_test_items = 1 else: # For 5+ items, 20% is fine min_test_items = max(1, int(num_items * actual_test_size)) # Ensure test split is not too large, e.g. not more than 50% unless dataset is very small # And ensure train always has at least one item if num_items > 1 calculated_test_size = min_test_items / num_items if ( calculated_test_size >= 0.5 and num_items > 2 ): # If test is 50% or more and we have 3+ items calculated_test_size = ( num_items - 1 ) / num_items # Make train have at least 1 split_dataset = full_dataset.train_test_split( test_size=calculated_test_size, seed=42 ) self.train = split_dataset["train"] self.test = split_dataset["test"] # Final check for empty train/test after split, should not happen with logic above if num_items > 0 if len(self.train) == 0 and len(self.test) > 0: print( "Warning: Train set empty after split, test set has data. " "This is unusual. Swapping." ) self.train = self.test # Fallback, though indicates issue elif len(self.test) == 0 and len(self.train) > 0: print( "Warning: Test set empty after split, train set has data. " "Using full train set for test as well." ) self.test = self.train self.iter = 0 # Initialize the adaptive curriculum queue if self.config.resume_from_unsolved_dataset: print( f"🔄 Resume mode: Loading unsolved items from {self.config.resume_from_unsolved_dataset}" ) print( f" Note: This will override the dataset_name '{self.config.dataset_name}' for training items" ) print(f" Test set will still use items from '{self.config.dataset_name}'") await self._load_from_unsolved_dataset() else: # Initialize with all training items self.active_train_queue = list(self.train) self.solved_items = [] self.item_attempt_counts = {} print( f"Dataset setup complete. Train size: {len(self.train)}, Test size: {len(self.test)}" ) print( f"Adaptive curriculum initialized with {len(self.active_train_queue)} items in active queue" ) def _validate_config(self): """Validate configuration for potential conflicts and warn user.""" if self.config.resume_from_unsolved_dataset and self.config.dataset_name: print("⚠️ Configuration Notice:") print( f" Both 'dataset_name' ({self.config.dataset_name}) and 'resume_from_unsolved_dataset' are set" ) print(" Behavior:") print( " - Training items: Will come from the resume file (overrides dataset_name)" ) print(" - Test/eval items: Will come from dataset_name") print( " - This is useful for resuming training while keeping consistent evaluation" ) print() async def _get_score_from_verifier( self, model_response_text: str, func_name: str, args: Dict ) -> float: """Helper to call verifier function and get a numerical score. Also enforces strict ... formatting. """ # 1. Count and tags num_think_open = len(re.findall(r"", model_response_text, re.IGNORECASE)) num_think_close = len( re.findall(r"", model_response_text, re.IGNORECASE) ) if not (num_think_open == 1 and num_think_close == 1): return 0.0 # 3. Find the first occurrence of and try: think_open_match = re.search(r"", model_response_text, re.IGNORECASE) think_close_match = re.search( r"", model_response_text, re.IGNORECASE ) # These should exist due to the count check, but access .start() and .end() safely idx_think_open = think_open_match.start() idx_think_close_start = think_close_match.start() idx_think_close_end = think_close_match.end() except AttributeError: return 0.0 # 4. If appears after , malformed. if idx_think_open >= idx_think_close_start: # print(f"DEBUG: tag appears at or after tag. Response: '{model_response_text[:200]}...'") return 0.0 # 5. Extract text_to_verify (content after the first ) text_to_verify = model_response_text[idx_think_close_end:].strip() # 6. Check if text_to_verify itself contains any further or tags. if re.search(r"", text_to_verify, re.IGNORECASE) or re.search( r"", text_to_verify, re.IGNORECASE ): return 0.0 # If all checks pass, proceed with verification using text_to_verify if func_name not in IF_FUNCTIONS_MAP: print( f"Warning: Verifier function '{func_name}' not found in IF_FUNCTIONS_MAP." ) return 0.0 verifier_func = IF_FUNCTIONS_MAP[func_name] raw_score = None try: if func_name == "validate_placeholders": raw_score = verifier_func(text_to_verify, N=args.get("N")) elif func_name == "verify_bullet_points": raw_score = verifier_func(text_to_verify, N=args.get("N")) elif func_name == "validate_repeat_prompt": raw_score = verifier_func( text_to_verify, args.get("original_prompt", "") ) else: from inspect import signature sig = signature(verifier_func) valid_params = [p for p in sig.parameters if p != "text"] filtered_args = { k: args[k] for k in valid_params if k in args and args[k] is not None } raw_score = verifier_func(text_to_verify, **filtered_args) except LangDetectException: print( f"Warning: langdetect failed for func_name '{func_name}'. Scoring as incorrect." ) return 0.0 except ImportError as e: print( f"Warning: ImportError during verifier function '{func_name}': {e}. Check dependencies." ) return 0.0 except TypeError as e: print( f"TypeError calling {func_name} with args {args}: {e}. Text: '{text_to_verify[:100]}...'" ) return 0.0 except Exception as e: print( f"Unexpected error in verifier function '{func_name}' with args {args}: {e}" ) return 0.0 if isinstance(raw_score, tuple): score_value = float(raw_score[0]) elif isinstance(raw_score, bool): score_value = float(raw_score) else: print( f"Warning: Verifier '{func_name}' returned unexpected type: {type(raw_score)}. Expected bool or tuple." ) score_value = 0.0 return score_value async def rollout_and_score_eval(self, test_item: Dict): # test_item is a dictionary from the test set, processed by setup() # It should contain 'prompt', 'func_name', 'args' instruction_prompt_text = test_item["prompt"] func_name = test_item["func_name"] args_for_verifier = test_item["args"] messages = [{"role": "system", "content": system_prompt}] messages.append({"role": "user", "content": instruction_prompt_text}) prompt_str = self.tokenizer.apply_chat_template( messages, add_generation_prompt=True, tokenize=False ) async with self.server.managed_server(tokenizer=self.tokenizer) as managed: completion = await managed.completion( prompt=prompt_str, n=1, max_tokens=self.config.max_token_length, # Use config for max_tokens temperature=0.2, # Temperature for eval, can be 0 for deterministic split="eval", ) model_response_text = completion.choices[0].text score_value = await self._get_score_from_verifier( model_response_text, func_name, args_for_verifier ) return ( score_value # Returns 1.0 for correct, 0.0 for incorrect based on verifier ) async def evaluate(self, *args, **kwargs): # Evaluates the model on the test set if not self.test or len(self.test) == 0: print("Warning: Test set is empty. Skipping evaluation.") self.eval_metrics.append(("eval/percent_correct", 0.0)) return print(f"Starting evaluation on {len(self.test)} items...") eval_tasks = [] for test_item_dict in self.test: # self.test contains dicts after setup eval_tasks.append(self.rollout_and_score_eval(test_item_dict)) scores = await tqdm_asyncio.gather(*eval_tasks, desc="Evaluating") if not scores: # If gather returns empty list percent_correct = 0.0 else: percent_correct = sum(scores) / len(scores) self.eval_metrics.append(("eval/percent_correct", percent_correct)) print(f"Evaluation finished. Percent correct: {percent_correct:.4f}") async def collect_trajectories( self, item: Item ) -> Tuple[Optional[ScoredDataGroup], List]: # item = (prompt_messages_tuple, answer_info_dict) # answer_info_dict = {"func_name": ..., "args": ...} prompt_messages_list = [dict(msg_fset) for msg_fset in item[0]] answer_info = item[1] prompt_str = self.tokenizer.apply_chat_template( prompt_messages_list, add_generation_prompt=True, tokenize=False ) try: async with self.server.managed_server(tokenizer=self.tokenizer) as managed: completions = await managed.completion( prompt=prompt_str, n=self.config.group_size, max_tokens=self.config.max_token_length, temperature=0.8, # Temperature for diverse responses during training rollouts ) state = managed.get_state() nodes = state["nodes"] except Exception as e: print(f"ERROR: Exception during completion generation: {e}") return None, [] to_score_list = [] for i, choice in enumerate(completions.choices): trajectory_messages = [dict(msg_fset) for msg_fset in item[0]] # Fresh copy trajectory_messages.append({"role": "assistant", "content": choice.text}) to_score_list.append( { "messages": tuple(trajectory_messages), "answer_info": answer_info, "tokens": nodes[i].tokens, "masks": nodes[i].masked_tokens, "logprobs": nodes[i].logprobs, } ) # Pass answer_info if not to_score_list: return None, [] scored_data = await self.score(to_score_list) # Handle adaptive curriculum: decide whether to keep item in circulation if scored_data and scored_data.get("scores"): group_average_score = sum(scored_data["scores"]) / len( scored_data["scores"] ) self._handle_item_result(item, group_average_score, scored_data["scores"]) elif scored_data is None: # If scored_data is None, it might be because the group was skipped for being too easy # We need to calculate the scores ourselves to handle the item properly temp_scores = [] for rollout_item in to_score_list: trajectory_messages = rollout_item["messages"] answer_info = rollout_item["answer_info"] model_response_text = trajectory_messages[-1]["content"] func_name = answer_info["func_name"] args_for_verifier = answer_info["args"] # Get score (1.0 for correct, 0.0 for incorrect from verifier) score_value = await self._get_score_from_verifier( model_response_text, func_name, args_for_verifier ) reward = 1.0 if score_value == 1.0 else 0 temp_scores.append(reward) if temp_scores: group_average_score = sum(temp_scores) / len(temp_scores) self._handle_item_result(item, group_average_score, temp_scores) to_backlog = [] # Backlog not currently used but part of signature return scored_data, to_backlog def save_checkpoint(self, step, data=None): if data is None: data = {} data["iter"] = self.iter data["save_file_batch_num"] = self.save_file_batch_num data["failed_save_file_batch_num"] = self.failed_save_file_batch_num # Save adaptive curriculum state data["active_train_queue"] = self.active_train_queue data["solved_items"] = self.solved_items data["item_attempt_counts"] = self.item_attempt_counts super().save_checkpoint(step, data) async def close(self): """Save any remaining rollouts and curriculum state before closing.""" if self.config.dump_rollouts and self.rollouts_to_save_buffer: print( f"Saving {len(self.rollouts_to_save_buffer)} remaining rollouts before closing..." ) await self._save_rollouts_to_jsonl() if self.config.dump_failed_rollouts and self.failed_rollouts_to_save_buffer: print( f"Saving {len(self.failed_rollouts_to_save_buffer)} remaining failed rollouts before closing..." ) await self._save_failed_rollouts_to_jsonl() # Save final curriculum state if self.active_train_queue: print( f"Saving final curriculum state with {len(self.active_train_queue)} unsolved items..." ) await self._dump_active_queue_dataset() await super().close() async def score(self, rollout_group_data: List[Dict]) -> Optional[ScoredDataGroup]: # rollout_group_data is a list of dicts with messages, answer_info, tokens, masks, logprobs # answer_info_dict = {"func_name": ..., "args": ...} scores_container = ScoredDataGroup() scores_container["tokens"] = list() scores_container["masks"] = list() scores_container["scores"] = list() scores_container["inference_logprobs"] = list() if not rollout_group_data: return None # The 'answer_info' (func_name, args) is consistent for all items in this group, # as it comes from the same initial prompt. # We can extract it once if needed, but it's passed per item. random.shuffle(rollout_group_data) # Shuffle to avoid bias # Data dumping: collect rollouts for saving (group format) rollouts_for_this_group = [] failed_rollouts_for_this_group = [] for trajectory_item in rollout_group_data: full_trajectory_messages = trajectory_item["messages"] answer_info = trajectory_item[ "answer_info" ] # {"func_name": ..., "args": ...} model_response_text = full_trajectory_messages[-1]["content"] func_name = answer_info["func_name"] args_for_verifier = answer_info["args"] # Get score (1.0 for correct, 0.0 for incorrect from verifier) score_value = await self._get_score_from_verifier( model_response_text, func_name, args_for_verifier ) # Map to reward: 1.0 for correct, 0 for incorrect reward = 1.0 if score_value == 1.0 else 0 # Prepare structured conversation for data dumping conversation = [dict(msg) for msg in full_trajectory_messages] # Create rollout dict for this specific rollout in the group rollout_dict = { "conversation": conversation, "score": reward, } # Collect rollouts for this group based on score and config if ( self.config.dump_rollouts and reward >= self.config.rollout_save_score_threshold ): rollouts_for_this_group.append(rollout_dict) elif self.config.dump_failed_rollouts and reward == 0: failed_rollouts_for_this_group.append(rollout_dict) tokens = trajectory_item["tokens"] masks = trajectory_item["masks"] logprobs = trajectory_item["logprobs"] # Filter out examples with insufficient context (too short) if ( sum(1 for m_val in masks if m_val != -100) < 10 ): # At least 10 non-masked tokens continue scores_container["tokens"].append(tokens) scores_container["masks"].append(masks) scores_container["inference_logprobs"].append(logprobs) scores_container["scores"].append(reward) # Stop if we have enough examples for the group if len(scores_container["tokens"]) >= self.config.group_size: break if not scores_container["tokens"]: # No valid items collected return None # Calculate group average score for difficulty filtering and logging current_scores = scores_container.get("scores", []) if current_scores: average_score = sum(current_scores) / len(current_scores) # Get task info from the first rollout's answer_info answer_info = ( rollout_group_data[0]["answer_info"] if rollout_group_data else {} ) func_name = answer_info.get("func_name", "unknown_task") # Check if group is too easy for training (but still allow data dumping) if average_score > self.config.max_group_average_for_training: print( f"Task: {func_name} | Group average score: {average_score:.4f} (SKIPPED - too easy for training, threshold: {self.config.max_group_average_for_training})" # noqa ) # Still handle data dumping for groups that are too easy for training # but might be useful for analysis if ( rollouts_for_this_group and average_score <= self.config.max_group_average_for_training + 0.1 ): # Small buffer for data collection # Extract item info for the group - get from first rollout's answer_info answer_info = rollout_group_data[0]["answer_info"] item_id = f"allenai_RLVR-IFeval_train_item_{answer_info.get('func_name', 'unknown')}_{hash(str(answer_info)) % 100000}" # noqa group_data_to_save = { "item_id": item_id, "rollouts": rollouts_for_this_group, "constraint_details": answer_info, # Store group-level metadata "group_average_score": average_score, # Add group average for analysis "skipped_for_training": True, # Mark as skipped for training } self.rollouts_to_save_buffer.append(group_data_to_save) if failed_rollouts_for_this_group: # Extract item info for the failed group answer_info = rollout_group_data[0]["answer_info"] item_id = f"allenai_RLVR-IFeval_train_item_{answer_info.get('func_name', 'unknown')}_{hash(str(answer_info)) % 100000}" # noqa failed_group_data_to_save = { "item_id": item_id, "rollouts": failed_rollouts_for_this_group, "constraint_details": answer_info, # Store group-level metadata "group_average_score": average_score, # Add group average for analysis "skipped_for_training": True, # Mark as skipped for training } self.failed_rollouts_to_save_buffer.append( failed_group_data_to_save ) # Save rollouts if buffer is getting large (batch processing) if ( self.config.dump_rollouts and len(self.rollouts_to_save_buffer) >= 100 ): await self._save_rollouts_to_jsonl() if ( self.config.dump_failed_rollouts and len(self.failed_rollouts_to_save_buffer) >= 50 ): await self._save_failed_rollouts_to_jsonl() return None # Skip this group for training log_message = ( f"Task: {func_name} | Group average score: {average_score:.4f}" ) if all(s >= 0.5 for s in current_scores): print(f"{log_message} (All correct in this group!)") elif all(s == 0.0 for s in current_scores): print(f"{log_message} (All failed - format/constraint violations!)") elif all(s < 0.5 for s in current_scores): print(f"{log_message} (All incorrect but some partial credit!)") else: print(log_message) # Create group data structure and add to buffers for data dumping (for training groups) if rollouts_for_this_group: # Extract item info for the group - get from first rollout's answer_info answer_info = rollout_group_data[0]["answer_info"] item_id = f"allenai_RLVR-IFeval_train_item_{answer_info.get('func_name', 'unknown')}_{hash(str(answer_info)) % 100000}" # noqa group_data_to_save = { "item_id": item_id, "rollouts": rollouts_for_this_group, "constraint_details": answer_info, # Store group-level metadata "group_average_score": ( current_scores[0] if len(current_scores) == 1 else sum(current_scores) / len(current_scores) ), # Add group average for analysis "skipped_for_training": False, # Mark as used for training } self.rollouts_to_save_buffer.append(group_data_to_save) if failed_rollouts_for_this_group: # Extract item info for the failed group answer_info = rollout_group_data[0]["answer_info"] item_id = f"allenai_RLVR-IFeval_train_item_{answer_info.get('func_name', 'unknown')}_{hash(str(answer_info)) % 100000}" # noqa failed_group_data_to_save = { "item_id": item_id, "rollouts": failed_rollouts_for_this_group, "constraint_details": answer_info, # Store group-level metadata "group_average_score": ( current_scores[0] if len(current_scores) == 1 else sum(current_scores) / len(current_scores) ), # Add group average for analysis "skipped_for_training": False, # Mark as used for training } self.failed_rollouts_to_save_buffer.append(failed_group_data_to_save) # Save rollouts if buffer is getting large (batch processing) if self.config.dump_rollouts and len(self.rollouts_to_save_buffer) >= 100: await self._save_rollouts_to_jsonl() if ( self.config.dump_failed_rollouts and len(self.failed_rollouts_to_save_buffer) >= 50 ): await self._save_failed_rollouts_to_jsonl() # Record success rate for logging (based on positive rewards) for rwd in scores_container["scores"]: self.percent_correct_buffer.append( max(0, rwd) ) # If reward is 1.0, it's a success # Optional: Apply length penalty if all responses are correct (reward 1.0) # This logic is from SingleToolCallingEnv, may need adjustment for IF if all(s == 1.0 for s in scores_container["scores"]): token_lengths = [len(t) for t in scores_container["tokens"]] if not token_lengths or max(token_lengths) == 0: return scores_container # Avoid division by zero, or if all empty max_allowed_length = self.config.max_token_length # Threshold can be adjusted, e.g., 75% of max_token_length length_threshold = max_allowed_length * 0.75 penalized_scores = [] for i, length in enumerate(token_lengths): original_score = scores_container["scores"][i] # Should be 1.0 here if length <= length_threshold: penalized_scores.append(original_score) else: # Linear penalty for exceeding threshold penalty_factor = (length - length_threshold) / ( max_allowed_length - length_threshold ) penalty_factor = min(penalty_factor, 1.0) # Cap penalty factor at 1 # Penalized score scales from original_score down to original_score * (1-1) = 0 penalized_scores.append(original_score * (1.0 - penalty_factor)) scores_container["scores"] = penalized_scores # If all scores are identical after potential penalties, no learning signal if ( len(set(scores_container["scores"])) <= 1 and len(scores_container["scores"]) > 1 ): return None # Avoid sending data with no variance return scores_container async def get_next_item(self) -> Item: # Fetches the next item from the adaptive curriculum queue if not self.active_train_queue: # If active queue is empty, check if we have any items left if not self.solved_items: print("Error: No training data available in get_next_item.") # Return a dummy item to prevent crashes dummy_prompt_messages = ( frozenset({"role": "system", "content": system_prompt}.items()), frozenset( { "role": "user", "content": "Dummy instruction: say hello.", }.items() ), ) dummy_answer_info = { "func_name": "verify_keywords", "args": {"keyword_list": ["hello"]}, } return (dummy_prompt_messages, dummy_answer_info) else: # All items have been solved! Reset the queue with solved items for continued training print( f"🎉 All {len(self.solved_items)} items have been solved! Resetting queue for continued training..." ) self.active_train_queue = list(self.solved_items) self.solved_items = [] # Reset attempt counts for the new cycle self.item_attempt_counts = {} # Get the next item from the front of the active queue raw_item = self.active_train_queue.pop(0) self.iter += 1 # Create a unique identifier for this item for tracking item_id = f"{raw_item['func_name']}_{hash(str(raw_item)) % 100000}" # Track attempt count if item_id not in self.item_attempt_counts: self.item_attempt_counts[item_id] = 0 self.item_attempt_counts[item_id] += 1 instruction_prompt_text = raw_item["prompt"] # Construct messages for the LLM (prompt tuple part of Item) # Using frozenset as required by BaseEnv's Item type hint prompt_messages_tuple = ( frozenset({"role": "system", "content": system_prompt}.items()), frozenset({"role": "user", "content": instruction_prompt_text}.items()), ) # The "answer" part for scoring purposes (answer_info dict part of Item) answer_info = { "func_name": raw_item["func_name"], "args": raw_item["args"], # Add item tracking info "item_id": item_id, "raw_item": raw_item, # Store the full item for queue management "attempt_count": self.item_attempt_counts[item_id], # Optionally include other info for logging/debugging if needed from raw_item "original_constraints_for_logging": raw_item.get( "original_constraints", "" ), "expected_response_for_logging": raw_item.get( "expected_response_for_logging", "" ), } # Dump active queue every 100 iterations for resumability if self.iter % 100 == 0 and self.iter > 0: await self._dump_active_queue_dataset() return (prompt_messages_tuple, answer_info) async def _dump_active_queue_dataset(self): """ Dumps the current active queue to a JSONL file so training can be resumed from the unsolved items if the environment is shut down. """ if not self.active_train_queue: print("No active items to dump - all items have been solved!") return try: if not os.path.exists(self.datasets_dir): os.makedirs(self.datasets_dir) print(f"Created datasets directory: {self.datasets_dir}") except Exception as e: print(f"Error creating datasets directory {self.datasets_dir}: {e}") return filename = os.path.join(self.datasets_dir, "remaining_unsolved.jsonl") try: with open(filename, "w", encoding="utf-8") as f: for item in self.active_train_queue: # Add metadata about the current state item_with_metadata = dict(item) item_id = f"{item['func_name']}_{hash(str(item)) % 100000}" item_with_metadata["_curriculum_metadata"] = { "item_id": item_id, "attempt_count": self.item_attempt_counts.get(item_id, 0), "queue_position": self.active_train_queue.index(item), "total_active": len(self.active_train_queue), "total_solved": len(self.solved_items), "iteration_dumped": self.iter, } json.dump(item_with_metadata, f, ensure_ascii=False) f.write("\n") print( f"📁 Dumped {len(self.active_train_queue)} unsolved items to {filename} (iteration {self.iter})" ) print( f" Queue status: {len(self.active_train_queue)} active, {len(self.solved_items)} solved" ) except Exception as e: print(f"Error dumping active queue to {filename}: {e}") async def _load_from_unsolved_dataset(self): """ Load the active queue from a previously saved remaining_unsolved.jsonl file. This allows resuming training from where it left off. Note: When resuming, the training items come from the saved file, but the test set still comes from the current dataset_name configuration. """ try: with open( self.config.resume_from_unsolved_dataset, "r", encoding="utf-8" ) as f: loaded_items = [] loaded_attempt_counts = {} original_dataset_info = None for line in f: item_data = json.loads(line.strip()) # Extract curriculum metadata if present metadata = item_data.pop("_curriculum_metadata", {}) item_id = metadata.get("item_id") attempt_count = metadata.get("attempt_count", 0) # Store info about the original dataset for validation if original_dataset_info is None and "iteration_dumped" in metadata: original_dataset_info = { "total_active_at_dump": metadata.get("total_active"), "total_solved_at_dump": metadata.get("total_solved"), "iteration_dumped": metadata.get("iteration_dumped"), } if item_id and attempt_count > 0: loaded_attempt_counts[item_id] = attempt_count # Validate that the item has the expected structure required_fields = ["prompt", "func_name", "args"] if not all(field in item_data for field in required_fields): print( f"Warning: Skipping malformed item missing required fields: {list(item_data.keys())}" ) continue loaded_items.append(item_data) if not loaded_items: raise ValueError("No valid items found in resume file") self.active_train_queue = loaded_items self.solved_items = [] # Start with no solved items when resuming self.item_attempt_counts = loaded_attempt_counts print( f"📂 Loaded {len(loaded_items)} unsolved items from {self.config.resume_from_unsolved_dataset}" ) if loaded_attempt_counts: avg_attempts = sum(loaded_attempt_counts.values()) / len( loaded_attempt_counts ) print( f" Restored attempt counts for {len(loaded_attempt_counts)} items (avg: {avg_attempts:.1f} attempts)" # noqa ) if original_dataset_info: print( f" Original dump info: {original_dataset_info['total_solved_at_dump']} solved, " f"{original_dataset_info['total_active_at_dump']} active at iteration {original_dataset_info['iteration_dumped']}" # noqa ) # Validate compatibility with current dataset if hasattr(self, "train") and len(self.train) > 0: original_total = original_dataset_info.get( "total_active_at_dump", 0 ) + original_dataset_info.get("total_solved_at_dump", 0) current_total = len(self.train) if original_total != current_total: print("⚠️ Warning: Dataset size mismatch!") print( f" Original dataset had {original_total} items, current dataset has {current_total} items" # noqa ) print( " This might indicate different dataset versions or configurations" ) except FileNotFoundError: print( f"❌ Resume file not found: {self.config.resume_from_unsolved_dataset}" ) print("Falling back to full dataset initialization...") self.active_train_queue = list(self.train) self.solved_items = [] self.item_attempt_counts = {} except Exception as e: print(f"❌ Error loading from unsolved dataset: {e}") print("Falling back to full dataset initialization...") self.active_train_queue = list(self.train) self.solved_items = [] self.item_attempt_counts = {} def _handle_item_result( self, item: Item, group_average_score: float, group_scores: List[float] = None ): """ Handle the result of an item based on its group average score and individual scores. If solved (high score), remove from circulation. If not solved (low score), add back to the end of the queue. """ _, answer_info = item raw_item = answer_info.get("raw_item") item_id = answer_info.get("item_id") attempt_count = answer_info.get("attempt_count", 1) if not raw_item or not item_id: return # Skip if we don't have the necessary info # Define "solved" based on configuration options is_solved = False solve_reason = "" # Check if solved based on single correct rollout if self.config.solve_on_single_correct and group_scores: if any(score >= 1.0 for score in group_scores): is_solved = True solve_reason = " (single correct)" # Check if solved based on group average thresholds (original logic) if not is_solved: if group_average_score > self.config.max_group_average_for_training: is_solved = True solve_reason = " (too easy)" elif group_average_score >= 0.9: # Very high performance threshold is_solved = True solve_reason = " (mastered)" if is_solved: # Item is solved - move to solved items (remove from circulation) self.solved_items.append(raw_item) status = f"SOLVED - removed from circulation{solve_reason}" else: # Item not solved - add back to the end of the active queue self.active_train_queue.append(raw_item) status = "NOT SOLVED - added back to queue" # Log the decision periodically or for items with many attempts if attempt_count % 5 == 1 or attempt_count <= 3 or is_solved: func_name = answer_info.get("func_name", "unknown") print( f"📚 Item {item_id} (attempt #{attempt_count}): {func_name} | Score: {group_average_score:.4f} | {status}" # noqa ) print( f" Queue status: {len(self.active_train_queue)} active, {len(self.solved_items)} solved" ) async def add_rollouts_for_wandb( self, scored_data: ScoredDataGroup, # Assuming single ScoredDataGroup here item: Item = None, # item = (prompt_messages_tuple, answer_info_dict) ): # Saves rollouts for WandB logging num_keep = self.config.num_rollouts_per_group_for_logging if num_keep == -1: # Log all rollouts in the group num_keep = len(scored_data["tokens"]) # item[1] is the answer_info_dict containing func_name and args constraint_details_for_log = item[1] if item else {} rollout_batch = [] for i in range(min(num_keep, len(scored_data["tokens"]))): decoded_text = self.tokenizer.decode( scored_data["tokens"][i], skip_special_tokens=False ) score = scored_data["scores"][i] rollout_batch.append((decoded_text, score, constraint_details_for_log)) self.rollouts_for_wandb.append(rollout_batch) # Limit the number of rollout groups stored if len(self.rollouts_for_wandb) > self.config.num_rollouts_to_keep: self.rollouts_for_wandb.pop(0) async def _save_rollouts_to_jsonl(self): """Saves the buffered rollouts to a JSONL file in the datadumps directory.""" if not self.rollouts_to_save_buffer: print("Warning: _save_rollouts_to_jsonl called but buffer is empty!") return buffer_size = len(self.rollouts_to_save_buffer) print(f"Starting save of {buffer_size} rollout groups to JSONL file...") try: if not os.path.exists(self.datadumps_dir): os.makedirs(self.datadumps_dir) print(f"Created directory: {self.datadumps_dir}") except Exception as e: print(f"Error creating directory {self.datadumps_dir}: {e}") return filename = os.path.join( self.datadumps_dir, f"instruction_following_rollouts_{self.run_uuid}_{self.save_file_batch_num:04d}.jsonl", ) try: with open(filename, "w", encoding="utf-8") as f: for rollout_dict in self.rollouts_to_save_buffer: json.dump(rollout_dict, f, ensure_ascii=False) f.write("\n") print(f"Successfully saved {buffer_size} rollout groups to {filename}") self.save_file_batch_num += 1 self.rollouts_to_save_buffer.clear() except Exception as e: print(f"Error saving rollouts to {filename}: {e}") async def _save_failed_rollouts_to_jsonl(self): """Saves the buffered failed rollouts to a JSONL file for debugging.""" if not self.failed_rollouts_to_save_buffer: print("Warning: _save_failed_rollouts_to_jsonl called but buffer is empty!") return buffer_size = len(self.failed_rollouts_to_save_buffer) print(f"Starting save of {buffer_size} failed rollout groups to JSONL file...") try: if not os.path.exists(self.datadumps_dir): os.makedirs(self.datadumps_dir) print(f"Created directory: {self.datadumps_dir}") except Exception as e: print(f"Error creating directory {self.datadumps_dir}: {e}") return filename = os.path.join( self.datadumps_dir, f"instruction_following_failed_rollouts_{self.run_uuid}_{self.failed_save_file_batch_num:04d}.jsonl", ) try: with open(filename, "w", encoding="utf-8") as f: for rollout_dict in self.failed_rollouts_to_save_buffer: json.dump(rollout_dict, f, ensure_ascii=False) f.write("\n") print( f"Successfully saved {buffer_size} failed rollout groups to {filename}" ) self.failed_save_file_batch_num += 1 self.failed_rollouts_to_save_buffer.clear() except Exception as e: print(f"Error saving failed rollouts to {filename}: {e}") # ----- IFEval Verifier Functions and Map ----- # adapted from https://github.com/allenai/open-instruct/blob/main/scripts/eval_constraints/if_functions.py # Helper function for verify_keyword_frequency, moved import re to top level def _extract_words(text: str) -> List[str]: return re.findall(r"\\b\\w+\\b", text.lower()) # include keywords: Include keywords {keyword1}, {keyword2} in your response def verify_keywords(text: str, keyword_list: List[str]) -> bool: response_lower = text.lower() return all(keyword.lower() in response_lower for keyword in keyword_list) # Keyword Frequency: In your response, the word {word} should appear {N} times. def verify_keyword_frequency(text: str, word: str, N: int) -> bool: text_lower = text.lower() keyword_lower = word.lower() words = _extract_words(text_lower) actual_count = sum(1 for w in words if w == keyword_lower) return actual_count == N # Forbidden Words: Do not include keywords {forbidden words} in the response. def validate_forbidden_words(text: str, forbidden_words: List[str]) -> bool: text_lower = text.lower() return not any(word.lower() in text_lower for word in forbidden_words) # Letter Frequency : In your response, the letter {letter} should appear {N} times. def verify_letter_frequency(text: str, letter: str, N: int) -> bool: if len(letter) != 1: # This should ideally raise ValueError, but for RL reward, return False return False actual_count = text.count(letter) return actual_count == N # Response Language: Your ENTIRE response should be in {language}, no other language is allowed. def validate_response_language(text: str, language: str) -> bool: try: detected_language = detect(text) return detected_language == language except LangDetectException: # Catching specific exception from detect() print( f"Warning: langdetect failed to detect language for text: '{text[:50]}...'" ) return False # Number Paragraphs: Your response should contain {N} paragraphs. You separate paragraphs using the markdown divider: # * * * def verify_paragraph_count(text: str, N: int) -> bool: def clean_text(txt: str) -> str: return "\\n".join(line.strip() for line in txt.splitlines()).strip() cleaned_text = clean_text(text) # Paragraphs are separated by '* * *'. N dividers mean N+1 paragraphs. # If the text IS paragraphs, then N paragraphs will have N-1 dividers. # The prompt implies N paragraphs are expected. # If N=1, 0 dividers. If N=2, 1 divider. So, count of parts = N. paragraphs = cleaned_text.split("* * *") actual_count = len(paragraphs) # Verify each split resulted in non-empty content, if text itself is not empty if not cleaned_text and N == 0: return True # 0 paragraphs, empty text if not cleaned_text and N > 0: return False # This check might be too strict if empty paragraphs are allowed by the constraint definition # If "paragraph" implies non-empty content: # return len(valid_paragraphs) == N and actual_count == N # If constraint just means N segments separated by dividers: return actual_count == N # Number Words: Answer with at least / around / at most {N} words def validate_word_constraint(text: str, N: int, quantifier: str) -> bool: words = text.strip().split() actual_count = len(words) tolerance = max(round(N * 0.1), 1) # For 'around' if quantifier == "at least": return actual_count >= N elif quantifier == "at most": return actual_count <= N elif quantifier == "around": return abs(actual_count - N) <= tolerance return False # Number Sentences: Answer with at least / around / at most {N} sentences. def verify_sentence_constraint(text: str, N: int, quantifier: str) -> bool: # Basic sentence splitting, might need more robust NLP for complex cases sentences = re.split( r"(?= N elif quantifier == "around": # "around" for sentences usually means exact or +/-1 return abs(actual_count - N) <= 1 elif quantifier == "at most": return actual_count <= N return False # Number Paragraphs + First Word in i-th Paragraph def validate_paragraphs(text: str, N: int, first_word: str, i: int) -> bool: # Paragraphs separated by double line breaks paragraphs = text.split("\\n\\n") if len(paragraphs) != N: return False # i is 1-indexed for paragraph number if not (1 <= i <= len(paragraphs)): return False # Check first word of the i-th paragraph # .strip() to handle leading/trailing whitespace in paragraph # .split()[0] to get the first word try: actual_first_word = paragraphs[i - 1].strip().split()[0] # Case-insensitive comparison for first_word might be more robust return actual_first_word.lower() == first_word.lower() except IndexError: # Handles empty paragraph or paragraph without words return False # Postscript: At the end of your response, please explicitly add a postscript starting with {postscript marker} def verify_postscript(text: str, postscript_marker: str) -> bool: marker_index = text.rfind(postscript_marker) # Find last occurrence if marker_index == -1: return False # Check if it's truly a postscript (i.e., near the end, and has content after marker) # This interpretation: marker exists, and something follows it OR it's at the very end. # The original IFEval might have a stricter definition (e.g. specific distance from end) # A simple check: marker is present and the text from marker to end is mostly the postscript. # For RL, simpler: marker is present and is not just prefix of a word. # Test if the marker is at a word boundary if it's not the start of the string if ( marker_index > 0 and text[marker_index - 1].isalnum() and postscript_marker[0].isalnum() ): # Avoid matching mid-word, e.g. "script" in "postscript" if marker is "script" # This check is heuristic. A regex with word boundaries might be better. pass # Heuristic, might need refinement # Check if content exists after marker, or if marker itself is the end remaining_text = text[marker_index:].strip() return len(remaining_text) >= len(postscript_marker.strip()) # Number Placeholder: The response must contain at least {N} placeholders ... [address]. def validate_placeholders(text: str, N: int) -> Tuple[bool, List[str]]: placeholders_found = re.findall(r"\\[(.*?)\\]", text) # Matches [content] return len(placeholders_found) >= N, placeholders_found # Number Bullets: Your answer must contain exactly {N} bullet points. * This is a point. def verify_bullet_points( text: str, N: int ) -> bool: # Original had tuple[bool,str] in doc, bool in code lines = text.splitlines() # Markdown bullets usually start with '*', '-', or '+' followed by a space. bullet_points = [ line.strip() for line in lines if re.match(r"^(\\s*)[\\*\\-\\+]\\s+", line.strip()) ] return len(bullet_points) == N # Title: Your answer must contain a title, wrapped in double angular brackets, such as <>. def validate_title(text: str) -> bool: return bool(re.search(r"<<(.*?)>>", text)) # Choose: From Answer with one of the following options: {options} def validate_choice(text: str, options: List[str]) -> bool: # Assuming 'text' should be one of the 'options' exactly, or contain one of them. # The original prompt "Answer with one of..." implies the response *is* one of the options. # Case-insensitive comparison for robustness. text_cleaned = text.strip().lower() return any(text_cleaned == opt.strip().lower() for opt in options) # Minimum Number Highlighted Section: Highlight at least {N} sections ... *highlighted section* def validate_highlighted_sections(text: str, N: int) -> bool: # Markdown italics/bold *highlight* or **highlight** # This regex looks for single asterisks: *content* matches = re.findall( r"\*(.*?)(?= N # Multiple Sections: Your response must have {N} sections. Mark ... with {section splitter} X. def validate_sections(text: str, N: int, section_splitter: str) -> bool: # Example: section_splitter = "Section" -> "Section 1", "Section 2" # This implies the splitter itself might include a number or be just the prefix. # If splitter is "---", then text.split("---"). # If splitter is "Topic X:", this is more complex. # Assuming a simple string split is intended by the original IFEval function. # The prompt phrasing "Mark the beginning of each section with {section splitter} X" # suggests counting occurrences of the splitter pattern. # If section_splitter is like "SECTION", we'd look for "SECTION 1", "SECTION 2", ... # This is hard to generalize perfectly without knowing how IFEval defines 'X'. # Simplest: count occurrences of the base splitter string. # sections = text.split(section_splitter) # num_sections = len(sections) -1 if sections[0].strip() == "" else len(sections) # A slightly more robust way for "Splitter X": # Count how many times "splitter" followed by something (like a number) appears. # Example: if splitter is "Chapter", we look for "Chapter 1", "Chapter ...". # This regex is a placeholder for more specific logic IFEval might use. # Let's use a simple count of the splitter string for now. # This might need to be adjusted based on IFEval's exact expectation for "X". # For "SECTION 1.", "SECTION 2.", if splitter is "SECTION ": actual_sections = len( re.findall( re.escape(section_splitter) + r"\\s*\\d*[:\\.\\s]", text, re.IGNORECASE ) ) # If N=0 and no splitters, it's true. If N>0 and no splitters, false. if N == 0: return actual_sections == 0 return actual_sections == N # JSON Format : Entire output should be wrapped in JSON format. def validate_json_format(text: str) -> bool: try: json.loads(text.strip()) # .strip() to handle leading/trailing whitespace return True except json.JSONDecodeError: return False # Repeat Prompt: First, repeat the request without change, then give your answer def validate_repeat_prompt(text: str, original_prompt: str) -> bool: # Normalize whitespace for comparison robustness text_norm = " ".join(text.strip().split()) original_prompt_norm = " ".join(original_prompt.strip().split()) return text_norm.startswith(original_prompt_norm) # Two Responses: Give two different responses. Separated by 6 asterisk symbols: ******. def validate_two_responses(text: str) -> bool: if text.count("******") == 1: parts = text.split("******") if len(parts) == 2: # Check if parts are non-empty and different resp1 = parts[0].strip() resp2 = parts[1].strip() return bool(resp1 and resp2 and resp1 != resp2) return False # All Uppercase: Your entire response should be in English, capital letters only. def validate_uppercase(text: str) -> bool: # Check if it has letters and all letters are uppercase if not any( c.isalpha() for c in text ): # No letters, technically not violating "all capital" return True # Or False, depending on interpretation of "response" return text == text.upper() # All Lowercase: Your entire response should be in English, and in all lowercase letters. def validate_lowercase(text: str) -> bool: if not any(c.isalpha() for c in text): return True return text == text.lower() # Frequency of All-capital Words def validate_frequency_capital_words(text: str, N: int, quantifier: str) -> bool: # Words with all capital letters, e.g., "NASA", "AI". Min 2 chars to be a "word". capital_words = re.findall(r"\\b[A-Z]{2,}\\b", text) actual_count = len(capital_words) tolerance = max(round(N * 0.1), 1) # For 'around' if quantifier == "at least": return actual_count >= N elif quantifier == "at most": return actual_count <= N elif ( quantifier == "around" ): # Using exact for 'around' with capital words unless specified return abs(actual_count - N) <= tolerance # Or just actual_count == N return False # End Checker: Finish your response with this exact phrase {end phrase}. def validate_end(text: str, end_phrase: str) -> bool: # Normalize whitespace at the end of text for robustness return text.strip().endswith(end_phrase.strip()) # Quotation: Wrap your entire response with double quotation marks. def validate_quotation(text: str) -> bool: stripped_text = text.strip() return stripped_text.startswith('"') and stripped_text.endswith('"') # No Commas: In your entire response, refrain from the use of any commas. def validate_no_commas(text: str) -> bool: return "," not in text IF_FUNCTIONS_MAP = { "verify_keywords": verify_keywords, "verify_keyword_frequency": verify_keyword_frequency, "validate_forbidden_words": validate_forbidden_words, "verify_letter_frequency": verify_letter_frequency, "validate_response_language": validate_response_language, "verify_paragraph_count": verify_paragraph_count, "validate_word_constraint": validate_word_constraint, "verify_sentence_constraint": verify_sentence_constraint, "validate_paragraphs": validate_paragraphs, "verify_postscript": verify_postscript, "validate_placeholders": validate_placeholders, "verify_bullet_points": verify_bullet_points, "validate_title": validate_title, "validate_choice": validate_choice, "validate_highlighted_sections": validate_highlighted_sections, "validate_sections": validate_sections, "validate_json_format": validate_json_format, "validate_repeat_prompt": validate_repeat_prompt, "validate_two_responses": validate_two_responses, "validate_uppercase": validate_uppercase, "validate_lowercase": validate_lowercase, "validate_frequency_capital_words": validate_frequency_capital_words, "validate_end": validate_end, "validate_quotation": validate_quotation, "validate_no_commas": validate_no_commas, } if __name__ == "__main__": InstructionFollowingEnv.cli()