diff --git a/environments/instruction_following_algorithm_environment.py b/environments/instruction_following_algorithm_environment.py
index 7e1e13b2..40dfb9d9 100644
--- a/environments/instruction_following_algorithm_environment.py
+++ b/environments/instruction_following_algorithm_environment.py
@@ -1,12 +1,13 @@
 import json
 import random
 import re
-from typing import Dict, List, Optional, Tuple, Union
+from typing import Dict, List, Optional, Tuple

 import wandb
-from datasets import load_dataset, Dataset  # Added Dataset for dummy data
+from datasets import Dataset, load_dataset  # Added Dataset for dummy data
+from langdetect import LangDetectException, detect
+from pydantic import Field  # Added import for Field
 from tqdm.asyncio import tqdm_asyncio
-from langdetect import detect, LangDetectException

 from atroposlib.envs.base import (
     APIServerConfig,
@@ -17,7 +18,6 @@ from atroposlib.envs.base import (
     ScoredDataGroup,
 )
 from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer
-from pydantic import Field  # Added import for Field

 # System prompt can be reused or adapted for instruction following tasks
 system_prompt = (
@@ -30,15 +30,20 @@ system_prompt = (

 class IFConfig(BaseEnvConfig):
     dataset_name: str = Field("allenai/RLVR-IFeval", description="Default dataset name")
-    dataset_config_name: Optional[str] = Field(None, description="Dataset config name, if any")
-    test_set_ratio: float = Field(0.05, description="The ratio of the selected dataset for testing")
+    dataset_config_name: Optional[str] = Field(
+        None, description="Dataset config name, if any"
+    )
+    test_set_ratio: float = Field(
+        0.05, description="The ratio of the selected dataset for testing"
+    )


 class InstructionFollowingEnv(BaseEnv):
-    env_config_cls = IFConfig # Added env_config_cls for IFConfig
+    env_config_cls = IFConfig  # Added env_config_cls for IFConfig
+
     def __init__(
         self,
-        config: IFConfig, # Changed BaseEnvConfig to IFConfig
+        config: IFConfig,  # Changed BaseEnvConfig to IFConfig
         server_configs: List[APIServerConfig],
         slurm=True,
         testing=False,
@@ -50,24 +55,26 @@ class InstructionFollowingEnv(BaseEnv):
         # self.completion_lengths = []  # Kept from SingleToolCallingEnv, assess utility

     @classmethod
-    def config_init(self) -> Tuple[IFConfig, List[APIServerConfig]]: # Changed BaseEnvConfig to IFConfig
+    def config_init(
+        self,
+    ) -> Tuple[IFConfig, List[APIServerConfig]]:  # Changed BaseEnvConfig to IFConfig
         # Configuration for the Instruction Following Environment
-        env_config = IFConfig( # Changed BaseEnvConfig to IFConfig
-            tokenizer_name="NousResearch/DeepHermes-3-Llama-3-8B-Preview",
-            group_size=16,
+        env_config = IFConfig(  # Changed BaseEnvConfig to IFConfig
+            tokenizer_name="NousResearch/DeepHermes-3-Llama-3-8B-Preview",
+            group_size=16,
             use_wandb=True,
-            rollout_server_url="http://localhost:8000",
+            rollout_server_url="http://localhost:8000",
             total_steps=500,
-            batch_size=1024,
+            batch_size=1024,
             steps_per_eval=20,
-            max_token_length=1024 * 16,
+            max_token_length=1024 * 16,
             inference_weight=1.0,
-            wandb_name="instruction_following_rlvr_ifeval", # Specific WandB project name
+            wandb_name="instruction_following_rlvr_ifeval",  # Specific WandB project name
             eval_handling=EvalHandlingEnum.LIMIT_TRAIN,
             eval_limit_ratio=0.1,
-            dataset_name="allenai/RLVR-IFeval", # Default dataset
-            dataset_config_name=None, # RLVR-IFeval doesn't have a specific config name, uses 'default'
-            test_set_ratio=0.05 # The ratio of the selelcted dataset in %
+            dataset_name="allenai/RLVR-IFeval",  # Default dataset
+            dataset_config_name=None,  # RLVR-IFeval doesn't have a specific config name, uses 'default'
+            test_set_ratio=0.05,  # The ratio of the selected dataset used for testing
         )
         # Server configurations can be similar to SingleToolCallingEnv or adjusted
         server_configs = [
@@ -103,7 +110,7 @@ class InstructionFollowingEnv(BaseEnv):
                 self.percent_correct_buffer
             ) / len(self.percent_correct_buffer)
         except ZeroDivisionError:
-            pass # Buffer might be empty
+            pass  # Buffer might be empty
         self.percent_correct_buffer = list()

         for item in self.eval_metrics:
@@ -118,41 +125,68 @@ class InstructionFollowingEnv(BaseEnv):
         Each item from RLVR-IFeval is expected to have:
         - 'messages': A list of dictionaries, e.g., [{'role': 'user', 'content': 'instruction...'}]
         - 'ground_truth': A JSON string containing 'func_name' and arguments for the verifier.
-
+
         The method will parse these to produce items for the environment with:
         - 'prompt': The user's instruction string.
         - 'func_name': The string name of the verifier function.
         - 'args': A dictionary of arguments for that verifier function.
         """
         dataset_name = getattr(self.config, "dataset_name", "allenai/RLVR-IFeval")
-        dataset_config_name = getattr(self.config, "dataset_config_name", None) # Default is None, RLVR-IFeval has no sub-config
+        dataset_config_name = getattr(
+            self.config, "dataset_config_name", None
+        )  # Default is None, RLVR-IFeval has no sub-config

         processed_items = []
         try:
-            print(f"Attempting to load dataset: {dataset_name}, config: {dataset_config_name if dataset_config_name else 'default'}")
+            print(
+                f"Attempting to load dataset: {dataset_name}, "
+                f"config: {dataset_config_name if dataset_config_name else 'default'}"
+            )
             if dataset_config_name:
-                full_dataset_raw = load_dataset(dataset_name, dataset_config_name, split="train", trust_remote_code=True)
+                full_dataset_raw = load_dataset(
+                    dataset_name,
+                    dataset_config_name,
+                    split="train",
+                    trust_remote_code=True,
+                )
             else:
-                full_dataset_raw = load_dataset(dataset_name, split="train", trust_remote_code=True)
-            print(f"Successfully loaded raw dataset. Number of items: {len(full_dataset_raw)}")
+                full_dataset_raw = load_dataset(
+                    dataset_name, split="train", trust_remote_code=True
+                )
+            print(
+                f"Successfully loaded raw dataset. Number of items: {len(full_dataset_raw)}"
+            )

             for i, item in enumerate(full_dataset_raw):
                 # Extract prompt from 'messages' field
                 item_messages = item.get("messages")
-                if not item_messages or not isinstance(item_messages, list) or len(item_messages) == 0:
-                    print(f"Warning: Item {i} has invalid or empty 'messages' field. Skipping. Item: {item}")
+                if (
+                    not item_messages
+                    or not isinstance(item_messages, list)
+                    or len(item_messages) == 0
+                ):
+                    print(
+                        f"Warning: Item {i} has invalid or empty 'messages' field. Skipping. Item: {item}"
+                    )
                     continue

                 # Assuming the relevant prompt is the content of the first message in the list
                 # (or last, if multiple user messages were possible, but IFEval is typically single user instruction)
                 prompt_text = item_messages[0].get("content")
                 if not prompt_text:
-                    print(f"Warning: Item {i} '{item_messages[0]}' has no content. Skipping.")
+                    print(
+                        f"Warning: Item {i} '{item_messages[0]}' has no content. Skipping."
+                    )
                     continue

                 # Get the ground_truth JSON string
                 ground_truth_json_str = item.get("ground_truth")
-                if not ground_truth_json_str or not isinstance(ground_truth_json_str, str):
-                    print(f"Warning: Item {i} missing or has invalid 'ground_truth' string. Skipping. Prompt: {prompt_text[:50]}...")
+                if not ground_truth_json_str or not isinstance(
+                    ground_truth_json_str, str
+                ):
+                    print(
+                        f"Warning: Item {i} missing or has invalid 'ground_truth' string. Skipping. "
+                        f"Prompt: {prompt_text[:50]}..."
+                    )
                     continue

                 try:
@@ -160,63 +194,93 @@ class InstructionFollowingEnv(BaseEnv):
                     if not isinstance(parsed_gt, dict):
                         raise ValueError("Parsed ground_truth is not a dictionary.")
                 except (json.JSONDecodeError, ValueError) as e:
-                    print(f"Warning: Could not parse 'ground_truth' JSON for item {i}. Error: {e}. GT String: '{ground_truth_json_str}'. Prompt: {prompt_text[:50]}... Skipping.")
+                    print(
+                        f"Warning: Could not parse 'ground_truth' JSON for item {i}. Error: {e}. "
+                        f"GT String: '{ground_truth_json_str}'. Prompt: {prompt_text[:50]}... Skipping."
+                    )
                     continue

                 func_name_from_gt = parsed_gt.get("func_name")
                 if not func_name_from_gt:
-                    print(f"Warning: Item {i} parsed 'ground_truth' has no 'func_name'. GT: {parsed_gt}. Prompt: {prompt_text[:50]}... Skipping.")
+                    print(
+                        f"Warning: Item {i} parsed 'ground_truth' has no 'func_name'. GT: {parsed_gt}. "
+                        f"Prompt: {prompt_text[:50]}... Skipping."
+                    )
                     continue
-
+
                 if func_name_from_gt not in IF_FUNCTIONS_MAP:
-                    print(f"Warning: func_name '{func_name_from_gt}' in item {i} not in IF_FUNCTIONS_MAP. Prompt: {prompt_text[:50]}... Skipping.")
+                    print(
+                        f"Warning: func_name '{func_name_from_gt}' in item {i} not in IF_FUNCTIONS_MAP. "
+                        f"Prompt: {prompt_text[:50]}... Skipping."
+                    )
                     continue

                 # Prepare args for the verifier function: remove func_name and keep others.
                 # Verifier functions will only use args they expect.
-                args_dict = {k: v for k, v in parsed_gt.items() if k != "func_name" and v is not None}
-
-                processed_items.append({
-                    "prompt": prompt_text,
-                    "func_name": func_name_from_gt,
-                    "args": args_dict,
-                    "original_constraints_for_logging": str(item.get("constraint", "")), # For logging, from RLVR-IFeval structure
-                    "expected_response_for_logging": "" # RLVR-IFeval doesn't seem to have a sample good response directly
-                })
+                args_dict = {
+                    k: v
+                    for k, v in parsed_gt.items()
+                    if k != "func_name" and v is not None
+                }
+
+                processed_items.append(
+                    {
+                        "prompt": prompt_text,
+                        "func_name": func_name_from_gt,
+                        "args": args_dict,
+                        "original_constraints_for_logging": str(
+                            item.get("constraint", "")
+                        ),  # For logging, from RLVR-IFeval structure
+                        "expected_response_for_logging": "",  # RLVR-IFeval doesn't seem to have a sample good response directly  # noqa: E501
+                    }
+                )

             if not processed_items:
-                print("Warning: No items successfully processed from the dataset. Check dataset format/content or parsing logic.")
-                raise ValueError("Dataset processing resulted in no valid items for RLVR-IFeval. Cannot proceed without data.")
-
+                print(
+                    "Warning: No items successfully processed from the dataset. "
+                    "Check dataset format/content or parsing logic."
+                )
+                raise ValueError(
+                    "Dataset processing resulted in no valid items for RLVR-IFeval. Cannot proceed without data."
+                )
+
             full_dataset = Dataset.from_list(processed_items)
-            print(f"Successfully processed {len(full_dataset)} items from dataset '{dataset_name}'.")
+            print(
+                f"Successfully processed {len(full_dataset)} items from dataset '{dataset_name}'."
+            )

         except Exception as e:
             # This block is a fallback if the primary dataset loading/processing fails catastrophically.
-            # For RLVR-IFeval, a failure here suggests issues with Hugging Face access, dataset integrity, or fundamental code errors.
- print(f"CRITICAL: Failed to load or process primary dataset '{dataset_name}': {e}. Using a DUMMY dataset as fallback.") + # For RLVR-IFeval, a failure here suggests issues with Hugging Face access, dataset integrity, or fundamental code errors. # noqa: E501 + print( + f"CRITICAL: Failed to load or process primary dataset '{dataset_name}': {e}. " + f"Using a DUMMY dataset as fallback." + ) dummy_data_for_fallback = [ { "prompt": "Dummy Instruction 1: Ensure your response contains the word 'example'.", "func_name": "verify_keywords", - "args": {"keyword_list": ["example"]}, + "args": {"keyword_list": ["example"]}, "original_constraints_for_logging": "Contains 'example'", - "expected_response_for_logging": "This is an example response." + "expected_response_for_logging": "This is an example response.", }, - { + { "prompt": "Dummy Instruction 2: Output a valid JSON with key 'data' and value 'test'.", "func_name": "validate_json_format", - "args": {}, + "args": {}, "original_constraints_for_logging": "Output valid JSON.", - "expected_response_for_logging": "{\\\"data\\\": \\\"test\\\"}" - } + "expected_response_for_logging": '{\\"data\\": \\"test\\"}', + }, ] full_dataset = Dataset.from_list(dummy_data_for_fallback) - print(f"Initialized with DUMMY dataset of {len(full_dataset)} items due to previous errors.") - + print( + f"Initialized with DUMMY dataset of {len(full_dataset)} items " + f"due to previous errors." + ) + full_dataset = full_dataset.shuffle(seed=42) - - actual_test_size = self.config.test_set_ratio # Read from config + + actual_test_size = self.config.test_set_ratio # Read from config num_items = len(full_dataset) if num_items == 0: @@ -227,48 +291,59 @@ class InstructionFollowingEnv(BaseEnv): print("Warning: Dataset has only 1 item. Using it for both train and test.") self.train = full_dataset self.test = full_dataset - else: # num_items > 1 + else: # num_items > 1 # Ensure test_size results in at least 1 item for test set if possible, but not more than train set - if num_items < 5 : # For 2,3,4 items, make test size 1 - min_test_items = 1 - else: # For 5+ items, 20% is fine - min_test_items = max(1, int(num_items * actual_test_size)) + if num_items < 5: # For 2,3,4 items, make test size 1 + min_test_items = 1 + else: # For 5+ items, 20% is fine + min_test_items = max(1, int(num_items * actual_test_size)) # Ensure test split is not too large, e.g. not more than 50% unless dataset is very small # And ensure train always has at least one item if num_items > 1 calculated_test_size = min_test_items / num_items - if calculated_test_size >= 0.5 and num_items > 2: # If test is 50% or more and we have 3+ items - calculated_test_size = (num_items -1) / num_items # Make train have at least 1 + if ( + calculated_test_size >= 0.5 and num_items > 2 + ): # If test is 50% or more and we have 3+ items + calculated_test_size = ( + num_items - 1 + ) / num_items # Make train have at least 1 - split_dataset = full_dataset.train_test_split(test_size=calculated_test_size, seed=42) + split_dataset = full_dataset.train_test_split( + test_size=calculated_test_size, seed=42 + ) self.train = split_dataset["train"] self.test = split_dataset["test"] # Final check for empty train/test after split, should not happen with logic above if num_items > 0 if len(self.train) == 0 and len(self.test) > 0: - print("Warning: Train set empty after split, test set has data. This is unusual. 
Swapping.") - self.train = self.test # Fallback, though indicates issue + print( + "Warning: Train set empty after split, test set has data. This is unusual. Swapping." + ) + self.train = self.test # Fallback, though indicates issue elif len(self.test) == 0 and len(self.train) > 0: - print("Warning: Test set empty after split, train set has data. Using full train set for test as well.") - self.test = self.train - + print( + "Warning: Test set empty after split, train set has data. Using full train set for test as well." + ) + self.test = self.train self.iter = 0 - print(f"Dataset setup complete. Train size: {len(self.train)}, Test size: {len(self.test)}") + print( + f"Dataset setup complete. Train size: {len(self.train)}, Test size: {len(self.test)}" + ) - - async def _get_score_from_verifier(self, model_response_text: str, func_name: str, args: Dict) -> float: + async def _get_score_from_verifier( + self, model_response_text: str, func_name: str, args: Dict + ) -> float: """Helper to call verifier function and get a numerical score.""" if func_name not in IF_FUNCTIONS_MAP: - print(f"Warning: Verifier function '{func_name}' not found in IF_FUNCTIONS_MAP.") + print( + f"Warning: Verifier function '{func_name}' not found in IF_FUNCTIONS_MAP." + ) return 0.0 verifier_func = IF_FUNCTIONS_MAP[func_name] - + raw_score = None try: - # For validate_response_language, langdetect is now imported at the top. - # Specific argument handling for functions that don't take generic **args - # or have special return types should be done before the generic call. if func_name == "validate_placeholders": # validate_placeholders expects 'text' and 'N', returns (bool, list) raw_score = verifier_func(model_response_text, N=args.get("N")) @@ -276,33 +351,60 @@ class InstructionFollowingEnv(BaseEnv): # verify_bullet_points expects 'text' and 'N', returns bool (or was (bool,str) in one doc) # Assuming it now consistently returns bool as per our integrated version raw_score = verifier_func(model_response_text, N=args.get("N")) - # Add other specific handlers here if necessary, otherwise use generic **args + # Special handler for repeat-prompt, which only wants (text, original_prompt) + elif func_name == "validate_repeat_prompt": + raw_score = verifier_func( + model_response_text, args.get("original_prompt", "") + ) + # Generic handler: only pass kwargs the function actually declares else: - raw_score = verifier_func(model_response_text, **args) + from inspect import signature - except LangDetectException: # Specifically catch for language detection issues - print(f"Warning: langdetect failed for func_name '{func_name}'. Scoring as incorrect.") + sig = signature(verifier_func) + # drop the 'text' parameter and keep only declared kwargs + valid_params = [p for p in sig.parameters if p != "text"] + filtered_args = { + k: args[k] + for k in valid_params + if k in args and args[k] is not None + } + raw_score = verifier_func(model_response_text, **filtered_args) + + except LangDetectException: # Specifically catch for language detection issues + print( + f"Warning: langdetect failed for func_name '{func_name}'. Scoring as incorrect." + ) return 0.0 except ImportError as e: # This might happen if a function tries a lazy import that fails (not langdetect now) - print(f"Warning: ImportError during verifier function '{func_name}': {e}. Check dependencies.") + print( + f"Warning: ImportError during verifier function '{func_name}': {e}. Check dependencies." 
+            )
             return 0.0
         except TypeError as e:
-            print(f"TypeError calling {func_name} with args {args}: {e}. Text: '{model_response_text[:100]}...'")
+            print(
+                f"TypeError calling {func_name} with args {args}: {e}. Text: '{model_response_text[:100]}...'"
+            )
             return 0.0
-        except Exception as e: # Catch any other unexpected error from a verifier
-            print(f"Unexpected error in verifier function '{func_name}' with args {args}: {e}")
+        except Exception as e:  # Catch any other unexpected error from a verifier
+            print(
+                f"Unexpected error in verifier function '{func_name}' with args {args}: {e}"
+            )
             return 0.0

         # Convert boolean or tuple[boolean, ...] to float score
         if isinstance(raw_score, tuple):
-            score_value = float(raw_score[0]) # Assuming the first element is the boolean score
+            score_value = float(
+                raw_score[0]
+            )  # Assuming the first element is the boolean score
         elif isinstance(raw_score, bool):
             score_value = float(raw_score)
         else:
-            print(f"Warning: Verifier '{func_name}' returned unexpected type: {type(raw_score)}. Expected bool or tuple.")
+            print(
+                f"Warning: Verifier '{func_name}' returned unexpected type: {type(raw_score)}. Expected bool or tuple."
+            )
             score_value = 0.0
-
+
         return score_value

     async def rollout_and_score_eval(self, test_item: Dict):
@@ -322,15 +424,19 @@ class InstructionFollowingEnv(BaseEnv):
         completion = await self.server.completion(
             prompt=prompt_str,
             n=1,
-            max_tokens=self.config.max_token_length, # Use config for max_tokens
-            temperature=0.7, # Temperature for eval, can be 0 for deterministic
+            max_tokens=self.config.max_token_length,  # Use config for max_tokens
+            temperature=0.7,  # Temperature for eval, can be 0 for deterministic
             split="eval",
         )
         model_response_text = completion.choices[0].text
-        score_value = await self._get_score_from_verifier(model_response_text, func_name, args_for_verifier)
-
-        return score_value # Returns 1.0 for correct, 0.0 for incorrect based on verifier
+        score_value = await self._get_score_from_verifier(
+            model_response_text, func_name, args_for_verifier
+        )
+
+        return (
+            score_value  # Returns 1.0 for correct, 0.0 for incorrect based on verifier
+        )

     async def evaluate(self, *args, **kwargs):
         # Evaluates the model on the test set
@@ -340,21 +446,22 @@ class InstructionFollowingEnv(BaseEnv):
             return

         eval_tasks = []
-        for test_item_dict in self.test: # self.test contains dicts after setup
+        for test_item_dict in self.test:  # self.test contains dicts after setup
             eval_tasks.append(self.rollout_and_score_eval(test_item_dict))
-
+
         scores = await tqdm_asyncio.gather(*eval_tasks)
-
-        if not scores: # If gather returns empty list
-            percent_correct = 0.0
+
+        if not scores:  # If gather returns empty list
+            percent_correct = 0.0
         else:
             percent_correct = sum(scores) / len(scores)
-
+
         self.eval_metrics.append(("eval/percent_correct", percent_correct))
         print(f"Evaluation percent correct: {percent_correct}")
-
-    async def collect_trajectories(self, item: Item) -> Tuple[Optional[ScoredDataGroup], List]:
+    async def collect_trajectories(
+        self, item: Item
+    ) -> Tuple[Optional[ScoredDataGroup], List]:
         # item = (prompt_messages_tuple, answer_info_dict)
         # answer_info_dict = {"func_name": ..., "args": ...}
         prompt_messages_list = [dict(msg_fset) for msg_fset in item[0]]
@@ -368,20 +475,22 @@ class InstructionFollowingEnv(BaseEnv):
             prompt=prompt_str,
             n=self.config.group_size,
             max_tokens=self.config.max_token_length,
-            temperature=0.8, # Temperature for diverse responses during training rollouts
+            temperature=0.8,  # Temperature for diverse responses during training rollouts
         )

         to_score_list = []
         for choice in completions.choices:
-            trajectory_messages = [dict(msg_fset) for msg_fset in item[0]] # Fresh copy
+            trajectory_messages = [dict(msg_fset) for msg_fset in item[0]]  # Fresh copy
             trajectory_messages.append({"role": "assistant", "content": choice.text})
-            to_score_list.append((tuple(trajectory_messages), answer_info)) # Pass answer_info
+            to_score_list.append(
+                (tuple(trajectory_messages), answer_info)
+            )  # Pass answer_info

         if not to_score_list:
             return None, []
-
+
         scored_data = await self.score(to_score_list)
-        to_backlog = [] # Backlog not currently used but part of signature
+        to_backlog = []  # Backlog not currently used but part of signature

         return scored_data, to_backlog

@@ -390,7 +499,7 @@ class InstructionFollowingEnv(BaseEnv):
         data = {}
         data["iter"] = self.iter
         super().save_checkpoint(step, data)
-
+
     async def score(
         self, rollout_group_data: List[Tuple[tuple, Dict]]
     ) -> Optional[ScoredDataGroup]:
@@ -408,20 +517,22 @@ class InstructionFollowingEnv(BaseEnv):
         # The 'answer_info' (func_name, args) is consistent for all items in this group,
         # as it comes from the same initial prompt.
         # We can extract it once if needed, but it's passed per item.
-
-        random.shuffle(rollout_group_data) # Shuffle to avoid bias
+
+        random.shuffle(rollout_group_data)  # Shuffle to avoid bias

         for trajectory_item in rollout_group_data:
             full_trajectory_messages = trajectory_item[0]
-            answer_info = trajectory_item[1] # {"func_name": ..., "args": ...}
-
+            answer_info = trajectory_item[1]  # {"func_name": ..., "args": ...}
+
             model_response_text = full_trajectory_messages[-1]["content"]
             func_name = answer_info["func_name"]
             args_for_verifier = answer_info["args"]

             # Get score (1.0 for correct, 0.0 for incorrect from verifier)
-            score_value = await self._get_score_from_verifier(model_response_text, func_name, args_for_verifier)
-
+            score_value = await self._get_score_from_verifier(
+                model_response_text, func_name, args_for_verifier
+            )
+
             # Map to reward: 1.0 for correct, 0 for incorrect
             reward = 1.0 if score_value == 1.0 else 0

@@ -433,7 +544,9 @@ class InstructionFollowingEnv(BaseEnv):
             masks = out_dict["masks"]

             # Filter out examples with insufficient context (too short)
-            if sum(1 for m_val in masks if m_val != -100) < 10: # At least 10 non-masked tokens
+            if (
+                sum(1 for m_val in masks if m_val != -100) < 10
+            ):  # At least 10 non-masked tokens
                 continue

             scores_container["tokens"].append(tokens)
@@ -443,42 +556,48 @@ class InstructionFollowingEnv(BaseEnv):
             # Stop if we have enough examples for the group
            if len(scores_container["tokens"]) >= self.config.group_size:
                 break
-
-        if not scores_container["tokens"]: # No valid items collected
+
+        if not scores_container["tokens"]:  # No valid items collected
            return None

         # Record success rate for logging (based on positive rewards)
         for rwd in scores_container["scores"]:
-            self.percent_correct_buffer.append(max(0, rwd)) # If reward is 1.0, it's a success
+            self.percent_correct_buffer.append(
+                max(0, rwd)
+            )  # If reward is 1.0, it's a success

         # Optional: Apply length penalty if all responses are correct (reward 1.0)
         # This logic is from SingleToolCallingEnv, may need adjustment for IF
         if all(s == 1.0 for s in scores_container["scores"]):
             token_lengths = [len(t) for t in scores_container["tokens"]]
             if not token_lengths or max(token_lengths) == 0:
-                return scores_container # Avoid division by zero, or if all empty
+                return scores_container  # Avoid division by zero, or if all empty

             max_allowed_length = self.config.max_token_length
             # Threshold can be adjusted, e.g., 75% of max_token_length
-            length_threshold = max_allowed_length * 0.75
+            length_threshold = max_allowed_length * 0.75

             penalized_scores = []
             for i, length in enumerate(token_lengths):
-                original_score = scores_container["scores"][i] # Should be 1.0 here
+                original_score = scores_container["scores"][i]  # Should be 1.0 here
                 if length <= length_threshold:
                     penalized_scores.append(original_score)
                 else:
                     # Linear penalty for exceeding threshold
-                    penalty_factor = (length - length_threshold) / (max_allowed_length - length_threshold)
-                    penalty_factor = min(penalty_factor, 1.0) # Cap penalty factor at 1
+                    penalty_factor = (length - length_threshold) / (
+                        max_allowed_length - length_threshold
+                    )
+                    penalty_factor = min(penalty_factor, 1.0)  # Cap penalty factor at 1
                     # Penalized score scales from original_score down to original_score * (1-1) = 0
                     penalized_scores.append(original_score * (1.0 - penalty_factor))

             scores_container["scores"] = penalized_scores

-        # If all scores are identical after potential penalties, no learning signal
-        if len(set(scores_container["scores"])) <= 1 and len(scores_container["scores"]) > 1 :
-            return None # Avoid sending data with no variance
+        if (
+            len(set(scores_container["scores"])) <= 1
+            and len(scores_container["scores"]) > 1
+        ):
+            return None  # Avoid sending data with no variance

         return scores_container

@@ -490,17 +609,21 @@ class InstructionFollowingEnv(BaseEnv):
             # Return a dummy item to prevent crashes, though this indicates a setup issue
             dummy_prompt_messages = (
                 frozenset({"role": "system", "content": system_prompt}.items()),
-                frozenset({"role": "user", "content": "Dummy instruction: say hello."}.items())
+                frozenset(
+                    {"role": "user", "content": "Dummy instruction: say hello."}.items()
+                ),
             )
-            dummy_answer_info = {"func_name": "verify_keywords", "args": {"keyword_list": ["hello"]}}
+            dummy_answer_info = {
+                "func_name": "verify_keywords",
+                "args": {"keyword_list": ["hello"]},
+            }
             return (dummy_prompt_messages, dummy_answer_info)
-
-        raw_item = self.train[self.iter % len(self.train)] # raw_item is a dict
+        raw_item = self.train[self.iter % len(self.train)]  # raw_item is a dict
         self.iter += 1

         instruction_prompt_text = raw_item["prompt"]
-
+
         # Construct messages for the LLM (prompt tuple part of Item)
         # Using frozenset as required by BaseEnv's Item type hint
         prompt_messages_tuple = (
@@ -513,31 +636,37 @@ class InstructionFollowingEnv(BaseEnv):
             "func_name": raw_item["func_name"],
             "args": raw_item["args"],
             # Optionally include other info for logging/debugging if needed from raw_item
-            "original_constraints_for_logging": raw_item.get("original_constraints", ""),
-            "expected_response_for_logging": raw_item.get("expected_response_for_logging", "")
+            "original_constraints_for_logging": raw_item.get(
+                "original_constraints", ""
+            ),
+            "expected_response_for_logging": raw_item.get(
+                "expected_response_for_logging", ""
+            ),
         }
-
+
         return (prompt_messages_tuple, answer_info)

     async def add_rollouts_for_wandb(
         self,
-        scored_data: ScoredDataGroup, # Assuming single ScoredDataGroup here
-        item: Item = None, # item = (prompt_messages_tuple, answer_info_dict)
+        scored_data: ScoredDataGroup,  # Assuming single ScoredDataGroup here
+        item: Item = None,  # item = (prompt_messages_tuple, answer_info_dict)
     ):
         # Saves rollouts for WandB logging
         num_keep = self.config.num_rollouts_per_group_for_logging
-        if num_keep == -1: # Log all rollouts in the group
+        if num_keep == -1:  # Log all rollouts in the group
len(scored_data["tokens"]) - + # item[1] is the answer_info_dict containing func_name and args constraint_details_for_log = item[1] if item else {} rollout_batch = [] for i in range(min(num_keep, len(scored_data["tokens"]))): - decoded_text = self.tokenizer.decode(scored_data["tokens"][i], skip_special_tokens=True) + decoded_text = self.tokenizer.decode( + scored_data["tokens"][i], skip_special_tokens=True + ) score = scored_data["scores"][i] rollout_batch.append((decoded_text, score, constraint_details_for_log)) - + self.rollouts_for_wandb.append(rollout_batch) # Limit the number of rollout groups stored @@ -548,15 +677,18 @@ class InstructionFollowingEnv(BaseEnv): # ----- IFEval Verifier Functions and Map ----- # adapted from https://github.com/allenai/open-instruct/blob/main/scripts/eval_constraints/if_functions.py + # Helper function for verify_keyword_frequency, moved import re to top level def _extract_words(text: str) -> List[str]: return re.findall(r"\\b\\w+\\b", text.lower()) + # include keywords: Include keywords {keyword1}, {keyword2} in your response def verify_keywords(text: str, keyword_list: List[str]) -> bool: response_lower = text.lower() return all(keyword.lower() in response_lower for keyword in keyword_list) + # Keyword Frequency: In your response, the word {word} should appear {N} times. def verify_keyword_frequency(text: str, word: str, N: int) -> bool: text_lower = text.lower() @@ -565,11 +697,13 @@ def verify_keyword_frequency(text: str, word: str, N: int) -> bool: actual_count = sum(1 for w in words if w == keyword_lower) return actual_count == N + # Forbidden Words: Do not include keywords {forbidden words} in the response. def validate_forbidden_words(text: str, forbidden_words: List[str]) -> bool: text_lower = text.lower() return not any(word.lower() in text_lower for word in forbidden_words) + # Letter Frequency : In your response, the letter {letter} should appear {N} times. def verify_letter_frequency(text: str, letter: str, N: int) -> bool: if len(letter) != 1: @@ -578,13 +712,16 @@ def verify_letter_frequency(text: str, letter: str, N: int) -> bool: actual_count = text.count(letter) return actual_count == N + # Response Language: Your ENTIRE response should be in {language}, no other language is allowed. def validate_response_language(text: str, language: str) -> bool: try: detected_language = detect(text) return detected_language == language - except LangDetectException: # Catching specific exception from detect() - print(f"Warning: langdetect failed to detect language for text: '{text[:50]}...'") + except LangDetectException: # Catching specific exception from detect() + print( + f"Warning: langdetect failed to detect language for text: '{text[:50]}...'" + ) return False @@ -593,18 +730,20 @@ def validate_response_language(text: str, language: str) -> bool: def verify_paragraph_count(text: str, N: int) -> bool: def clean_text(txt: str) -> str: return "\\n".join(line.strip() for line in txt.splitlines()).strip() + cleaned_text = clean_text(text) # Paragraphs are separated by '* * *'. N dividers mean N+1 paragraphs. # If the text IS paragraphs, then N paragraphs will have N-1 dividers. # The prompt implies N paragraphs are expected. # If N=1, 0 dividers. If N=2, 1 divider. So, count of parts = N. 
-
+
     paragraphs = cleaned_text.split("* * *")
     actual_count = len(paragraphs)
     # Verify each split resulted in non-empty content, if text itself is not empty
-    if not cleaned_text and N == 0 : return True # 0 paragraphs, empty text
-    if not cleaned_text and N > 0 : return False
-    valid_paragraphs = [p.strip() for p in paragraphs if p.strip()]
+    if not cleaned_text and N == 0:
+        return True  # 0 paragraphs, empty text
+    if not cleaned_text and N > 0:
+        return False
     # This check might be too strict if empty paragraphs are allowed by the constraint definition
     # If "paragraph" implies non-empty content:
     # return len(valid_paragraphs) == N and actual_count == N
@@ -616,7 +755,7 @@ def verify_paragraph_count(text: str, N: int) -> bool:
 def validate_word_constraint(text: str, N: int, quantifier: str) -> bool:
     words = text.strip().split()
     actual_count = len(words)
-    tolerance = max(round(N * 0.1), 1) # For 'around'
+    tolerance = max(round(N * 0.1), 1)  # For 'around'

     if quantifier == "at least":
         return actual_count >= N
@@ -626,10 +765,13 @@ def validate_word_constraint(text: str, N: int, quantifier: str) -> bool:
         return abs(actual_count - N) <= tolerance
     return False

+
 # Number Sentences: Answer with at least / around / at most {N} sentences.
 def verify_sentence_constraint(text: str, N: int, quantifier: str) -> bool:
     # Basic sentence splitting, might need more robust NLP for complex cases
-    sentences = re.split(r'(?
         return actual_count <= N
     return False

+
 # Number Paragraphs + First Word in i-th Paragraph
 def validate_paragraphs(text: str, N: int, first_word: str, i: int) -> bool:
     # Paragraphs separated by double line breaks
@@ -659,13 +802,13 @@ def validate_paragraphs(text: str, N: int, first_word: str, i: int) -> bool:
         actual_first_word = paragraphs[i - 1].strip().split()[0]
         # Case-insensitive comparison for first_word might be more robust
         return actual_first_word.lower() == first_word.lower()
-    except IndexError: # Handles empty paragraph or paragraph without words
+    except IndexError:  # Handles empty paragraph or paragraph without words
         return False


 # Postscript: At the end of your response, please explicitly add a postscript starting with {postscript marker}
 def verify_postscript(text: str, postscript_marker: str) -> bool:
-    marker_index = text.rfind(postscript_marker) # Find last occurrence
+    marker_index = text.rfind(postscript_marker)  # Find last occurrence
     if marker_index == -1:
         return False
     # Check if it's truly a postscript (i.e., near the end, and has content after marker)
@@ -674,10 +817,14 @@ def verify_postscript(text: str, postscript_marker: str) -> bool:
     # A simple check: marker is present and the text from marker to end is mostly the postscript.
     # For RL, simpler: marker is present and is not just prefix of a word.
     # Test if the marker is at a word boundary if it's not the start of the string
-    if marker_index > 0 and text[marker_index-1].isalnum() and postscript_marker[0].isalnum():
+    if (
+        marker_index > 0
+        and text[marker_index - 1].isalnum()
+        and postscript_marker[0].isalnum()
+    ):
         # Avoid matching mid-word, e.g. "script" in "postscript" if marker is "script"
         # This check is heuristic. A regex with word boundaries might be better.
-        pass # Heuristic, might need refinement
+        pass  # Heuristic, might need refinement

     # Check if content exists after marker, or if marker itself is the end
     remaining_text = text[marker_index:].strip()
@@ -686,19 +833,28 @@
 # Number Placeholder: The response must contain at least {N} placeholders ... [address].
 def validate_placeholders(text: str, N: int) -> Tuple[bool, List[str]]:
-    placeholders_found = re.findall(r'\[(.*?)\]', text) # Matches [content]
+    placeholders_found = re.findall(r"\[(.*?)\]", text)  # Matches [content]
     return len(placeholders_found) >= N, placeholders_found

+
 # Number Bullets: Your answer must contain exactly {N} bullet points. * This is a point.
-def verify_bullet_points(text: str, N: int) -> bool: # Original had tuple[bool,str] in doc, bool in code
+def verify_bullet_points(
+    text: str, N: int
+) -> bool:  # Original had tuple[bool,str] in doc, bool in code
     lines = text.splitlines()
     # Markdown bullets usually start with '*', '-', or '+' followed by a space.
-    bullet_points = [line.strip() for line in lines if re.match(r'^(\s*)[\*\-\+]\s+', line.strip())]
+    bullet_points = [
+        line.strip()
+        for line in lines
+        if re.match(r"^(\s*)[\*\-\+]\s+", line.strip())
+    ]
     return len(bullet_points) == N

+
 # Title: Your answer must contain a title, wrapped in double angular brackets, such as <>.
 def validate_title(text: str) -> bool:
-    return bool(re.search(r'<<(.*?)>>', text))
+    return bool(re.search(r"<<(.*?)>>", text))
+

 # Choose: From Answer with one of the following options: {options}
 def validate_choice(text: str, options: List[str]) -> bool:
@@ -708,15 +864,19 @@ def validate_choice(text: str, options: List[str]) -> bool:
     text_cleaned = text.strip().lower()
     return any(text_cleaned == opt.strip().lower() for opt in options)

+
 # Minimum Number Highlighted Section: Highlight at least {N} sections ... *highlighted section*
 def validate_highlighted_sections(text: str, N: int) -> bool:
     # Markdown italics/bold *highlight* or **highlight**
     # This regex looks for single asterisks: *content*
-    matches = re.findall(r'\*(.*?)\*(?
     return len(matches) >= N

+
 # Multiple Sections: Your response must have {N} sections. Mark ... with {section splitter} X.
 def validate_sections(text: str, N: int, section_splitter: str) -> bool:
     # Example: section_splitter = "Section" -> "Section 1", "Section 2"
@@ -726,7 +886,7 @@ def validate_sections(text: str, N: int, section_splitter: str) -> bool:
     # Assuming a simple string split is intended by the original IFEval function.
     # The prompt phrasing "Mark the beginning of each section with {section splitter} X"
     # suggests counting occurrences of the splitter pattern.
-
+
     # If section_splitter is like "SECTION", we'd look for "SECTION 1", "SECTION 2", ...
     # This is hard to generalize perfectly without knowing how IFEval defines 'X'.
     # Simplest: count occurrences of the base splitter string.
@@ -736,25 +896,31 @@ def validate_sections(text: str, N: int, section_splitter: str) -> bool:
     # Count how many times "splitter" followed by something (like a number) appears.
     # Example: if splitter is "Chapter", we look for "Chapter 1", "Chapter ...".
     # This regex is a placeholder for more specific logic IFEval might use.
-
+
     # Let's use a simple count of the splitter string for now.
     # This might need to be adjusted based on IFEval's exact expectation for "X".
# For "SECTION 1.", "SECTION 2.", if splitter is "SECTION ": - actual_sections = len(re.findall(re.escape(section_splitter) + r'\\s*\\d*[:\\.\\s]', text, re.IGNORECASE)) - + actual_sections = len( + re.findall( + re.escape(section_splitter) + r"\\s*\\d*[:\\.\\s]", text, re.IGNORECASE + ) + ) + # If N=0 and no splitters, it's true. If N>0 and no splitters, false. - if N == 0: return actual_sections == 0 + if N == 0: + return actual_sections == 0 return actual_sections == N # JSON Format : Entire output should be wrapped in JSON format. def validate_json_format(text: str) -> bool: try: - json.loads(text.strip()) # .strip() to handle leading/trailing whitespace + json.loads(text.strip()) # .strip() to handle leading/trailing whitespace return True except json.JSONDecodeError: return False + # Repeat Prompt: First, repeat the request without change, then give your answer def validate_repeat_prompt(text: str, original_prompt: str) -> bool: # Normalize whitespace for comparison robustness @@ -762,6 +928,7 @@ def validate_repeat_prompt(text: str, original_prompt: str) -> bool: original_prompt_norm = " ".join(original_prompt.strip().split()) return text_norm.startswith(original_prompt_norm) + # Two Responses: Give two different responses. Separated by 6 asterisk symbols: ******. def validate_two_responses(text: str) -> bool: if text.count("******") == 1: @@ -773,48 +940,59 @@ def validate_two_responses(text: str) -> bool: return bool(resp1 and resp2 and resp1 != resp2) return False + # All Uppercase: Your entire response should be in English, capital letters only. def validate_uppercase(text: str) -> bool: # Check if it has letters and all letters are uppercase - if not any(c.isalpha() for c in text): # No letters, technically not violating "all capital" - return True # Or False, depending on interpretation of "response" + if not any( + c.isalpha() for c in text + ): # No letters, technically not violating "all capital" + return True # Or False, depending on interpretation of "response" return text == text.upper() + # All Lowercase: Your entire response should be in English, and in all lowercase letters. def validate_lowercase(text: str) -> bool: if not any(c.isalpha() for c in text): return True return text == text.lower() + # Frequency of All-capital Words def validate_frequency_capital_words(text: str, N: int, quantifier: str) -> bool: # Words with all capital letters, e.g., "NASA", "AI". Min 2 chars to be a "word". - capital_words = re.findall(r'\\b[A-Z]{2,}\\b', text) + capital_words = re.findall(r"\\b[A-Z]{2,}\\b", text) actual_count = len(capital_words) - tolerance = max(round(N * 0.1), 1) # For 'around' + tolerance = max(round(N * 0.1), 1) # For 'around' if quantifier == "at least": return actual_count >= N elif quantifier == "at most": return actual_count <= N - elif quantifier == "around": # Using exact for 'around' with capital words unless specified - return abs(actual_count - N) <= tolerance # Or just actual_count == N + elif ( + quantifier == "around" + ): # Using exact for 'around' with capital words unless specified + return abs(actual_count - N) <= tolerance # Or just actual_count == N return False + # End Checker: Finish your response with this exact phrase {end phrase}. def validate_end(text: str, end_phrase: str) -> bool: # Normalize whitespace at the end of text for robustness return text.strip().endswith(end_phrase.strip()) + # Quotation: Wrap your entire response with double quotation marks. 
 def validate_quotation(text: str) -> bool:
     stripped_text = text.strip()
     return stripped_text.startswith('"') and stripped_text.endswith('"')

+
 # No Commas: In your entire response, refrain from the use of any commas.
 def validate_no_commas(text: str) -> bool:
     return "," not in text

+
 IF_FUNCTIONS_MAP = {
     "verify_keywords": verify_keywords,
     "verify_keyword_frequency": verify_keyword_frequency,