import inspect
import json
import random
import re
from typing import Dict, List, Optional, Tuple, Union

import wandb
from datasets import Dataset, load_dataset  # Dataset builds the processed / fallback data
from langdetect import LangDetectException, detect
from tqdm.asyncio import tqdm_asyncio

from atroposlib.envs.base import (
    APIServerConfig,
    BaseEnv,
    BaseEnvConfig,
    EvalHandlingEnum,
    Item,
    ScoredDataGroup,
)
from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer

# System prompt for the "think, then answer" format. The reasoning block it asks for is removed
# by InstructionFollowingEnv._extract_final_response before a verifier sees the response.
system_prompt = (
    "You are a deep thinking AI, you may use extremely long chains of thought to deeply consider the "
    "problem and deliberate with yourself via systematic reasoning processes to help come to a correct "
    "solution prior to answering. You should enclose your thoughts and internal monologue inside "
    "<think> </think> tags, and then provide your solution or response to the problem."
)


class InstructionFollowingEnv(BaseEnv):
    """RL environment rewarding responses that satisfy verifiable IFEval-style constraints.

    Each item pairs a user instruction with the name of a verifier in ``IF_FUNCTIONS_MAP`` and
    the keyword arguments for it. A rollout earns 1.0 when the verifier accepts the response
    (the text after the reasoning block) and -1.0 otherwise.
    """

    # Reasoning-block delimiters requested by ``system_prompt``.
    THINK_OPEN_TAG = "<think>"
    THINK_CLOSE_TAG = "</think>"

    def __init__(
        self,
        config: BaseEnvConfig,
        server_configs: List[APIServerConfig],
        slurm=True,
        testing=False,
    ):
        super().__init__(config, server_configs, slurm, testing)
        # 1/0 success flags for training rollouts since the last wandb_log call.
        self.percent_correct_buffer = list()
        # (metric_name, value) pairs produced by evaluate(), flushed by wandb_log.
        self.eval_metrics = list()
        # Groups of (decoded_text, score, constraint_info) tuples for the WandB rollout table.
        self.rollouts_for_wandb = []

    @classmethod
    def config_init(cls) -> Tuple[BaseEnvConfig, List[APIServerConfig]]:
        """Return the default environment config and the inference-server configs."""
        env_config = BaseEnvConfig(
            tokenizer_name="NousResearch/DeepHermes-3-Llama-3-8B-Preview",
            group_size=16,
            use_wandb=True,
            rollout_server_url="http://localhost:8000",
            total_steps=500,
            batch_size=1024,
            steps_per_eval=20,
            max_token_length=1024 * 16,
            inference_weight=1.0,
            wandb_name="instruction_following_rlvr_ifeval",  # Specific WandB project name
            eval_handling=EvalHandlingEnum.LIMIT_TRAIN,
            eval_limit_ratio=0.1,
            dataset_name="allenai/RLVR-IFeval",  # Default dataset
            dataset_config_name=None,  # RLVR-IFeval has no sub-config; 'default' is used
            test_set_ratio=0.05,  # Fraction (0-1) of the processed dataset held out for evaluation
        )
        server_configs = [
            APIServerConfig(
                model_name="NousResearch/DeepHermes-3-Llama-3-8B-Preview",
                base_url="http://localhost:9004/v1",
                api_key="x",
                num_max_requests_at_once=32,
                num_requests_for_eval=256,
            )
        ]
        return env_config, server_configs

    async def create_rollout_table(self, wandb_metrics):
        """Move buffered rollouts into a ``wandb.Table`` under ``train/rollouts`` and clear the buffer."""
        if len(self.rollouts_for_wandb) > 0:
            table = wandb.Table(columns=["text", "score", "constraint_details"])
            for group in self.rollouts_for_wandb:
                for text, score, constraint_info in group:
                    table.add_data(text, score, json.dumps(constraint_info))
            wandb_metrics["train/rollouts"] = table
        self.rollouts_for_wandb = []
        return wandb_metrics

    async def wandb_log(self, wandb_metrics: Optional[Dict] = None):
        """Add training accuracy and pending eval metrics, reset the buffers, then delegate to BaseEnv."""
        if wandb_metrics is None:
            wandb_metrics = dict()
        if self.percent_correct_buffer:
            wandb_metrics["train/percent_correct"] = sum(self.percent_correct_buffer) / len(
                self.percent_correct_buffer
            )
        self.percent_correct_buffer = list()
        for metric_name, value in self.eval_metrics:
            wandb_metrics[metric_name] = value
        self.eval_metrics = list()
        await super().wandb_log(wandb_metrics)

    async def setup(self):
        """Load and preprocess the instruction-following dataset and build train/test splits.

        Tailored to the ``allenai/RLVR-IFeval`` layout: each row has ``messages`` (a list of
        ``{'role', 'content'}`` dicts whose first entry is the user instruction) and
        ``ground_truth`` (a JSON string holding ``func_name`` plus that verifier's arguments).
        Each processed item holds ``prompt``, ``func_name``, ``args`` and two logging fields.

        If loading or processing fails for any reason, a two-item DUMMY dataset is used instead
        and the failure is printed.
        """
        dataset_name = getattr(self.config, "dataset_name", "allenai/RLVR-IFeval")
        dataset_config_name = getattr(self.config, "dataset_config_name", None)

        try:
            print(
                f"Attempting to load dataset: {dataset_name}, "
                f"config: {dataset_config_name if dataset_config_name else 'default'}"
            )
            if dataset_config_name:
                full_dataset_raw = load_dataset(
                    dataset_name, dataset_config_name, split="train", trust_remote_code=True
                )
            else:
                full_dataset_raw = load_dataset(dataset_name, split="train", trust_remote_code=True)
            print(f"Successfully loaded raw dataset. Number of items: {len(full_dataset_raw)}")

            processed_items = []
            for i, item in enumerate(full_dataset_raw):
                processed = self._parse_rlvr_item(i, item)
                if processed is not None:
                    processed_items.append(processed)

            if not processed_items:
                print(
                    "Warning: No items successfully processed from the dataset. "
                    "Check dataset format/content or parsing logic."
                )
                raise ValueError(
                    "Dataset processing resulted in no valid items for RLVR-IFeval. "
                    "Cannot proceed without data."
                )

            full_dataset = Dataset.from_list(processed_items)
            print(f"Successfully processed {len(full_dataset)} items from dataset '{dataset_name}'.")
        except Exception as e:
            # Deliberate best-effort fallback: keep the environment runnable, but make it loud.
            print(
                f"CRITICAL: Failed to load or process primary dataset '{dataset_name}': {e}. "
                "Using a DUMMY dataset as fallback."
            )
            full_dataset = Dataset.from_list(self._dummy_items())
            print(f"Initialized with DUMMY dataset of {len(full_dataset)} items due to previous errors.")

        full_dataset = full_dataset.shuffle(seed=42)
        self.train, self.test = self._split_train_test(full_dataset)
        self.iter = 0
        print(f"Dataset setup complete. Train size: {len(self.train)}, Test size: {len(self.test)}")

    @staticmethod
    def _parse_rlvr_item(index: int, item: Dict) -> Optional[Dict]:
        """Convert one raw RLVR-IFeval row into an environment item; return None to skip it."""
        item_messages = item.get("messages")
        if not item_messages or not isinstance(item_messages, list) or not isinstance(item_messages[0], dict):
            print(f"Warning: Item {index} has invalid or empty 'messages' field. Skipping. Item: {item}")
            return None

        # IFEval prompts are a single user instruction, so the first message carries it.
        prompt_text = item_messages[0].get("content")
        if not prompt_text:
            print(f"Warning: Item {index} '{item_messages[0]}' has no content. Skipping.")
            return None

        ground_truth_json_str = item.get("ground_truth")
        if not ground_truth_json_str or not isinstance(ground_truth_json_str, str):
            print(
                f"Warning: Item {index} missing or has invalid 'ground_truth' string. Skipping. "
                f"Prompt: {prompt_text[:50]}..."
            )
            return None

        try:
            parsed_gt = json.loads(ground_truth_json_str)
            if not isinstance(parsed_gt, dict):
                raise ValueError("Parsed ground_truth is not a dictionary.")
        except (json.JSONDecodeError, ValueError) as e:
            print(
                f"Warning: Could not parse 'ground_truth' JSON for item {index}. Error: {e}. "
                f"GT String: '{ground_truth_json_str}'. Prompt: {prompt_text[:50]}... Skipping."
            )
            return None

        func_name_from_gt = parsed_gt.get("func_name")
        if not func_name_from_gt:
            print(
                f"Warning: Item {index} parsed 'ground_truth' has no 'func_name'. GT: {parsed_gt}. "
                f"Prompt: {prompt_text[:50]}... Skipping."
            )
            return None
        if func_name_from_gt not in IF_FUNCTIONS_MAP:
            print(
                f"Warning: func_name '{func_name_from_gt}' in item {index} not in IF_FUNCTIONS_MAP. "
                f"Prompt: {prompt_text[:50]}... Skipping."
            )
            return None

        # Everything except func_name is a candidate verifier argument; None means "unused".
        args_dict = {k: v for k, v in parsed_gt.items() if k != "func_name" and v is not None}
        return {
            "prompt": prompt_text,
            "func_name": func_name_from_gt,
            "args": args_dict,
            "original_constraints_for_logging": str(item.get("constraint", "")),
            # RLVR-IFeval rows carry no reference response.
            "expected_response_for_logging": "",
        }

    @staticmethod
    def _dummy_items() -> List[Dict]:
        """Two hand-written items used only when the real dataset cannot be loaded."""
        return [
            {
                "prompt": "Dummy Instruction 1: Ensure your response contains the word 'example'.",
                "func_name": "verify_keywords",
                "args": {"keyword_list": ["example"]},
                "original_constraints_for_logging": "Contains 'example'",
                "expected_response_for_logging": "This is an example response.",
            },
            {
                "prompt": "Dummy Instruction 2: Output a valid JSON with key 'data' and value 'test'.",
                "func_name": "validate_json_format",
                "args": {},
                "original_constraints_for_logging": "Output valid JSON.",
                "expected_response_for_logging": '{"data": "test"}',
            },
        ]

    def _split_train_test(self, full_dataset):
        """Split into (train, test); both are non-empty whenever the dataset has 2+ items.

        The test size is ``config.test_set_ratio`` (default 0.05) of the items, clamped to
        ``[1, num_items - 1]``. An integer ``test_size`` is passed so ``datasets`` does not
        round a float fraction up by an extra item.
        """
        num_items = len(full_dataset)
        if num_items == 0:
            print("ERROR: Dataset is empty. Cannot create train/test split.")
            return Dataset.from_list([]), Dataset.from_list([])
        if num_items == 1:
            print("Warning: Dataset has only 1 item. Using it for both train and test.")
            return full_dataset, full_dataset

        test_ratio = getattr(self.config, "test_set_ratio", 0.05)
        num_test = min(max(1, int(num_items * test_ratio)), num_items - 1)
        split_dataset = full_dataset.train_test_split(test_size=num_test, seed=42)
        return split_dataset["train"], split_dataset["test"]

    @classmethod
    def _extract_final_response(cls, text: str) -> Optional[str]:
        """Return the answer part of a completion, i.e. the text after the reasoning block.

        - Text after the last ``</think>`` is returned, stripped.
        - A completion that opens ``<think>`` but never closes it (e.g. cut off at max_tokens)
          has no answer, so None is returned.
        - A completion with no reasoning tags is returned stripped.
        """
        close_idx = text.rfind(cls.THINK_CLOSE_TAG)
        if close_idx != -1:
            return text[close_idx + len(cls.THINK_CLOSE_TAG):].strip()
        if cls.THINK_OPEN_TAG in text:
            return None
        return text.strip()

    @staticmethod
    def _filter_verifier_args(verifier_func, args: Optional[Dict]) -> Dict:
        """Keep only the non-None entries of ``args`` that ``verifier_func`` accepts by keyword.

        ``args`` read back from a ``datasets.Dataset`` may include keys belonging to other
        verifiers, set to None (all rows share one 'args' column). Passing those through would
        raise TypeError.
        """
        if not args:
            return {}
        params = inspect.signature(verifier_func).parameters
        accepts_var_kw = any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values())
        # The first parameter receives the response text positionally.
        keyword_names = set(list(params)[1:])
        return {
            key: value
            for key, value in args.items()
            if value is not None and (accepts_var_kw or key in keyword_names)
        }

    async def _get_score_from_verifier(self, model_response_text: str, func_name: str, args: Dict) -> float:
        """Run the named verifier on the answer part of a completion.

        Returns 1.0 when the verifier passes. Returns 0.0 when it fails, errors, returns an
        unexpected type, or the completion has no answer after its reasoning block.
        """
        verifier_func = IF_FUNCTIONS_MAP.get(func_name)
        if verifier_func is None:
            print(f"Warning: Verifier function '{func_name}' not found in IF_FUNCTIONS_MAP.")
            return 0.0

        response_text = self._extract_final_response(model_response_text)
        if response_text is None:
            # Generation ended inside the reasoning block: there is nothing to verify.
            return 0.0

        verifier_kwargs = self._filter_verifier_args(verifier_func, args)
        try:
            raw_score = verifier_func(response_text, **verifier_kwargs)
        except LangDetectException:
            print(f"Warning: langdetect failed for func_name '{func_name}'. Scoring as incorrect.")
            return 0.0
        except ImportError as e:
            print(f"Warning: ImportError during verifier function '{func_name}': {e}. Check dependencies.")
            return 0.0
        except TypeError as e:
            print(f"TypeError calling {func_name} with args {verifier_kwargs}: {e}. Text: '{response_text[:100]}...'")
            return 0.0
        except Exception as e:
            print(f"Unexpected error in verifier function '{func_name}' with args {verifier_kwargs}: {e}")
            return 0.0

        # Some verifiers (e.g. validate_placeholders) return (passed, details).
        if isinstance(raw_score, tuple):
            raw_score = raw_score[0]
        if isinstance(raw_score, bool):
            return float(raw_score)
        print(f"Warning: Verifier '{func_name}' returned unexpected type: {type(raw_score)}. Expected bool or tuple.")
        return 0.0

    async def rollout_and_score_eval(self, test_item: Dict):
        """Generate one completion for a test item and return its verifier score (1.0 or 0.0)."""
        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": test_item["prompt"]},
        ]
        prompt_str = self.tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)

        completion = await self.server.completion(
            prompt=prompt_str,
            n=1,
            max_tokens=self.config.max_token_length,
            temperature=0.7,  # Temperature for eval, can be 0 for deterministic
            split="eval",
        )
        model_response_text = completion.choices[0].text
        return await self._get_score_from_verifier(model_response_text, test_item["func_name"], test_item["args"])

    async def evaluate(self, *args, **kwargs):
        """Score every test item concurrently and record ``eval/percent_correct``."""
        if self.test is None or len(self.test) == 0:
            print("Warning: Test set is empty. Skipping evaluation.")
            self.eval_metrics.append(("eval/percent_correct", 0.0))
            return

        eval_tasks = [self.rollout_and_score_eval(test_item_dict) for test_item_dict in self.test]
        scores = await tqdm_asyncio.gather(*eval_tasks)
        percent_correct = sum(scores) / len(scores) if scores else 0.0
        self.eval_metrics.append(("eval/percent_correct", percent_correct))
        print(f"Evaluation percent correct: {percent_correct}")

    async def collect_trajectories(self, item: Item) -> Tuple[Optional[ScoredDataGroup], List]:
        """Sample ``group_size`` completions for ``item`` and score them.

        ``item`` is ``(prompt_messages_tuple, answer_info)`` as built by get_next_item. The
        backlog returned alongside the scores is always empty.
        """
        prompt_messages_list = [dict(msg_fset) for msg_fset in item[0]]
        answer_info = item[1]

        prompt_str = self.tokenizer.apply_chat_template(
            prompt_messages_list, add_generation_prompt=True, tokenize=False
        )
        completions = await self.server.completion(
            prompt=prompt_str,
            n=self.config.group_size,
            max_tokens=self.config.max_token_length,
            temperature=0.8,  # Diverse responses for training rollouts
        )

        to_score_list = []
        for choice in completions.choices:
            trajectory_messages = [dict(msg_fset) for msg_fset in item[0]]  # Fresh copy per choice
            trajectory_messages.append({"role": "assistant", "content": choice.text})
            to_score_list.append((tuple(trajectory_messages), answer_info))

        if not to_score_list:
            return None, []

        scored_data = await self.score(to_score_list)
        return scored_data, []

    def save_checkpoint(self, step, data=None):
        """Persist the training-set cursor along with BaseEnv's checkpoint data."""
        if data is None:
            data = {}
        data["iter"] = self.iter
        super().save_checkpoint(step, data)

    async def score(self, rollout_group_data: List[Tuple[tuple, Dict]]) -> Optional[ScoredDataGroup]:
        """Tokenize and reward a group of ``(trajectory_messages, answer_info)`` rollouts.

        Rewards are 1.0 for a passing response and -1.0 otherwise. When every kept rollout passes,
        responses longer than 75% of ``max_token_length`` are penalised linearly. Returns None when
        no rollout is usable, or when all scores are identical (no learning signal).
        Note: shuffles ``rollout_group_data`` in place.
        """
        if not rollout_group_data:
            return None

        scores_container = ScoredDataGroup()
        scores_container["tokens"] = list()
        scores_container["masks"] = list()
        scores_container["scores"] = list()

        random.shuffle(rollout_group_data)  # Avoid always keeping the same completions when truncating
        for full_trajectory_messages, answer_info in rollout_group_data:
            model_response_text = full_trajectory_messages[-1]["content"]
            score_value = await self._get_score_from_verifier(
                model_response_text, answer_info["func_name"], answer_info["args"]
            )
            reward = 1.0 if score_value == 1.0 else -1.0

            out_dict = tokenize_for_trainer(self.tokenizer, [dict(msg) for msg in full_trajectory_messages])
            tokens = out_dict["tokens"]
            masks = out_dict["masks"]

            # Skip rollouts with fewer than 10 trainable (non-masked) tokens.
            if sum(1 for m_val in masks if m_val != -100) < 10:
                continue

            scores_container["tokens"].append(tokens)
            scores_container["masks"].append(masks)
            scores_container["scores"].append(reward)

            if len(scores_container["tokens"]) >= self.config.group_size:
                break

        if not scores_container["tokens"]:
            return None

        # Success rate for logging: 1 for a positive reward, 0 otherwise.
        for rwd in scores_container["scores"]:
            self.percent_correct_buffer.append(max(0, rwd))

        # All correct: prefer shorter responses via a linear penalty past 75% of max_token_length.
        if all(s == 1.0 for s in scores_container["scores"]):
            token_lengths = [len(t) for t in scores_container["tokens"]]
            if not token_lengths or max(token_lengths) == 0:
                return scores_container

            max_allowed_length = self.config.max_token_length
            length_threshold = max_allowed_length * 0.75
            penalized_scores = []
            for original_score, length in zip(scores_container["scores"], token_lengths):
                if length <= length_threshold:
                    penalized_scores.append(original_score)
                else:
                    penalty_factor = min(
                        (length - length_threshold) / (max_allowed_length - length_threshold), 1.0
                    )
                    penalized_scores.append(original_score * (1.0 - penalty_factor))
            scores_container["scores"] = penalized_scores

        if len(set(scores_container["scores"])) <= 1 and len(scores_container["scores"]) > 1:
            return None  # No variance, no learning signal
        return scores_container

    async def get_next_item(self) -> Item:
        """Return the next training item as ``(prompt_messages_tuple, answer_info)``, cycling the train set."""
        if len(self.train) == 0:
            # setup() should prevent this; return a trivial item rather than crash.
            print("Error: Training data is empty in get_next_item.")
            dummy_prompt_messages = (
                frozenset({"role": "system", "content": system_prompt}.items()),
                frozenset({"role": "user", "content": "Dummy instruction: say hello."}.items()),
            )
            dummy_answer_info = {"func_name": "verify_keywords", "args": {"keyword_list": ["hello"]}}
            return (dummy_prompt_messages, dummy_answer_info)

        raw_item = self.train[self.iter % len(self.train)]
        self.iter += 1

        # frozensets keep the messages hashable, matching BaseEnv's Item type.
        prompt_messages_tuple = (
            frozenset({"role": "system", "content": system_prompt}.items()),
            frozenset({"role": "user", "content": raw_item["prompt"]}.items()),
        )
        answer_info = {
            "func_name": raw_item["func_name"],
            "args": raw_item["args"],
            "original_constraints_for_logging": raw_item.get("original_constraints_for_logging", ""),
            "expected_response_for_logging": raw_item.get("expected_response_for_logging", ""),
        }
        return (prompt_messages_tuple, answer_info)

    async def add_rollouts_for_wandb(
        self,
        scored_data: ScoredDataGroup,
        item: Item = None,
    ):
        """Buffer up to ``num_rollouts_per_group_for_logging`` decoded rollouts for the WandB table."""
        num_keep = self.config.num_rollouts_per_group_for_logging
        if num_keep == -1:
            num_keep = len(scored_data["tokens"])

        # item[1] is the answer_info dict (func_name, args, logging fields).
        constraint_details_for_log = item[1] if item else {}

        rollout_batch = []
        for i in range(min(num_keep, len(scored_data["tokens"]))):
            decoded_text = self.tokenizer.decode(scored_data["tokens"][i], skip_special_tokens=True)
            rollout_batch.append((decoded_text, scored_data["scores"][i], constraint_details_for_log))
        self.rollouts_for_wandb.append(rollout_batch)

        if len(self.rollouts_for_wandb) > self.config.num_rollouts_to_keep:
            self.rollouts_for_wandb.pop(0)


# ----- IFEval Verifier Functions and Map -----
# Adapted from https://github.com/allenai/open-instruct/blob/main/scripts/eval_constraints/if_functions.py


def _extract_words(text: str) -> List[str]:
    """Return the lowercase word tokens (``\\w+`` runs) of ``text``."""
    return re.findall(r"\b\w+\b", text.lower())


# Include Keywords: Include keywords {keyword1}, {keyword2} in your response
def verify_keywords(text: str, keyword_list: List[str]) -> bool:
    """True when every keyword occurs (case-insensitive substring match)."""
    response_lower = text.lower()
    return all(keyword.lower() in response_lower for keyword in keyword_list)


# Keyword Frequency: In your response, the word {word} should appear {N} times.
def verify_keyword_frequency(text: str, word: str, N: int) -> bool:
    """True when ``word`` appears exactly ``N`` times as a whole word (case-insensitive)."""
    keyword_lower = word.lower()
    actual_count = sum(1 for w in _extract_words(text) if w == keyword_lower)
    return actual_count == N


# Forbidden Words: Do not include keywords {forbidden words} in the response.
def validate_forbidden_words(text: str, forbidden_words: List[str]) -> bool:
    """True when none of the forbidden words occurs (case-insensitive substring match)."""
    text_lower = text.lower()
    return not any(word.lower() in text_lower for word in forbidden_words)


# Letter Frequency: In your response, the letter {letter} should appear {N} times.
def verify_letter_frequency(text: str, letter: str, N: int) -> bool:
    """True when the single character ``letter`` appears exactly ``N`` times (case-sensitive)."""
    if len(letter) != 1:
        # Malformed constraint; for an RL reward, fail rather than raise.
        return False
    return text.count(letter) == N


# Response Language: Your ENTIRE response should be in {language}, no other language is allowed.
def validate_response_language(text: str, language: str) -> bool:
    """True when langdetect's language code for ``text`` equals ``language``."""
    try:
        return detect(text) == language
    except LangDetectException:
        print(f"Warning: langdetect failed to detect language for text: '{text[:50]}...'")
        return False


# Number Paragraphs: Your response should contain {N} paragraphs. You separate paragraphs using the
# markdown divider: * * *
def verify_paragraph_count(text: str, N: int) -> bool:
    """True when ``text`` has exactly ``N`` non-empty paragraphs separated by ``* * *``."""
    cleaned_text = "\n".join(line.strip() for line in text.splitlines()).strip()
    if not cleaned_text:
        return N == 0
    paragraphs = cleaned_text.split("* * *")
    # An empty segment (e.g. two adjacent dividers) is not a paragraph.
    if any(not p.strip() for p in paragraphs):
        return False
    return len(paragraphs) == N


# Number Words: Answer with at least / around / at most {N} words
def validate_word_constraint(text: str, N: int, quantifier: str) -> bool:
    """Check the whitespace-delimited word count; "around" allows max(round(10% of N), 1)."""
    actual_count = len(text.strip().split())
    tolerance = max(round(N * 0.1), 1)
    if quantifier == "at least":
        return actual_count >= N
    elif quantifier == "at most":
        return actual_count <= N
    elif quantifier == "around":
        return abs(actual_count - N) <= tolerance
    return False


# Number Sentences: Answer with at least / around / at most {N} sentences.
def verify_sentence_constraint(text: str, N: int, quantifier: str) -> bool:
    """Check the sentence count against ``N`` using ``quantifier`` (at least / around / at most).

    Sentences are split after ``.``, ``!`` or ``?`` followed by whitespace. This is a heuristic,
    not a full sentence tokenizer. "around" allows a difference of one.
    """
    sentences = [s for s in re.split(r"(?<=[.!?])\s+", text.strip()) if s.strip()]
    actual_count = len(sentences)
    if quantifier == "at least":
        return actual_count >= N
    elif quantifier == "around":
        return abs(actual_count - N) <= 1
    elif quantifier == "at most":
        return actual_count <= N
    return False


# Number Paragraphs + First Word in i-th Paragraph
def validate_paragraphs(text: str, N: int, first_word: str, i: int) -> bool:
    """True when ``text`` has ``N`` blank-line-separated paragraphs and paragraph ``i`` starts with ``first_word``.

    ``i`` is 1-indexed. The comparison is case-insensitive and ignores surrounding punctuation or
    markdown emphasis, so "**Hello**," matches "hello".
    """
    # Outer blank lines would otherwise create empty leading/trailing "paragraphs".
    paragraphs = text.strip().split("\n\n")
    if len(paragraphs) != N:
        return False
    if not (1 <= i <= len(paragraphs)):
        return False
    words = paragraphs[i - 1].split()
    if not words:
        return False
    edge_chars = ".,!?;:'\"*_()[]"
    return words[0].strip(edge_chars).lower() == first_word.strip(edge_chars).lower()


# Postscript: At the end of your response, please explicitly add a postscript starting with {postscript marker}
def verify_postscript(text: str, postscript_marker: str) -> bool:
    """True when the last occurrence of ``postscript_marker`` is followed by non-blank content."""
    marker_index = text.rfind(postscript_marker)
    if marker_index == -1:
        return False
    # A postscript needs a body after its marker (e.g. "P.S. See you soon"), not just the marker.
    return bool(text[marker_index + len(postscript_marker):].strip())


# Number Placeholder: The response must contain at least {N} placeholders represented by square
# brackets, such as [address].
def validate_placeholders(text: str, N: int) -> Tuple[bool, List[str]]:
    """Return ``(at least N [placeholders] present, list of placeholder contents)``."""
    placeholders_found = re.findall(r"\[(.*?)\]", text)
    return len(placeholders_found) >= N, placeholders_found


# Number Bullets: Your answer must contain exactly {N} bullet points. * This is a point.
def verify_bullet_points(text: str, N: int) -> bool:
    """True when exactly ``N`` lines start with a markdown bullet (``*``, ``-`` or ``+`` then whitespace)."""
    bullet_points = [line.strip() for line in text.splitlines() if re.match(r"[*+-]\s+", line.strip())]
    return len(bullet_points) == N


# Title: Your answer must contain a title, wrapped in double angular brackets, such as <<poem of joy>>.
def validate_title(text: str) -> bool:
    """True when a non-empty, single-line ``<<title>>`` appears anywhere in ``text``."""
    return bool(re.search(r"<<[^\n]+>>", text))


# Choose: From Answer with one of the following options: {options}
def validate_choice(text: str, options: List[str]) -> bool:
    """True when the whole response equals one of ``options`` (case-insensitive, trimmed)."""
    text_cleaned = text.strip().lower()
    return any(text_cleaned == opt.strip().lower() for opt in options)


# Minimum Number Highlighted Section: Highlight at least {N} sections in your answer with
# markdown, i.e. *highlighted section*
def validate_highlighted_sections(text: str, N: int) -> bool:
    """True when at least ``N`` non-blank ``*single*`` or ``**double**`` asterisk spans appear.

    Mirrors the official IFEval highlight checker: spans cannot cross a newline, and empty
    spans such as ``**`` are not counted.
    """
    single = re.findall(r"\*[^\n*]*\*", text)
    double = re.findall(r"\*\*[^\n*]*\*\*", text)
    num_highlights = sum(1 for h in single if h.strip("*").strip())
    num_highlights += sum(1 for h in double if h[2:-2].strip())
    return num_highlights >= N


# Multiple Sections: Your response must have {N} sections. Mark the beginning of each section with
# {section splitter} X.
def validate_sections(text: str, N: int, section_splitter: str) -> bool:
    """True when ``{section_splitter} <number>`` headers occur exactly ``N`` times (case-insensitive).

    A number is required after the splitter (e.g. "SECTION 1"), so prose that merely mentions the
    word (e.g. "this section ...") is not counted.
    """
    pattern = re.escape(section_splitter) + r"\s*\d+"
    return len(re.findall(pattern, text, re.IGNORECASE)) == N


# JSON Format : Entire output should be wrapped in JSON format.
def validate_json_format(text: str) -> bool:
    """True when the response parses as JSON, optionally wrapped in a markdown code fence.

    Like the official IFEval checker, a surrounding fence (three backticks, optionally followed
    by ``json``) is removed before parsing.
    """
    content = text.strip()
    if content.startswith("```"):
        content = content[3:]
        if content[:4].lower() == "json":
            content = content[4:]
        if content.endswith("```"):
            content = content[:-3]
        content = content.strip()
    try:
        json.loads(content)
    except json.JSONDecodeError:
        return False
    return True


# Repeat Prompt: First, repeat the request without change, then give your answer
def validate_repeat_prompt(text: str, original_prompt: str) -> bool:
    """True when the response starts with ``original_prompt`` (whitespace-normalized)."""
    text_norm = " ".join(text.strip().split())
    original_prompt_norm = " ".join(original_prompt.strip().split())
    return text_norm.startswith(original_prompt_norm)


# Two Responses: Give two different responses. Separated by 6 asterisk symbols: ******.
def validate_two_responses(text: str) -> bool:
    """True when exactly one ``******`` separates two non-empty, different responses."""
    if text.count("******") != 1:
        return False
    first, second = (part.strip() for part in text.split("******"))
    return bool(first and second and first != second)


# All Uppercase: Your entire response should be in English, capital letters only.
def validate_uppercase(text: str) -> bool:
    """True when no letter is lowercase (a response with no letters passes)."""
    if not any(c.isalpha() for c in text):
        return True
    return text == text.upper()


# All Lowercase: Your entire response should be in English, and in all lowercase letters.
def validate_lowercase(text: str) -> bool:
    """True when no letter is uppercase (a response with no letters passes)."""
    if not any(c.isalpha() for c in text):
        return True
    return text == text.lower()


# Frequency of All-capital Words
def validate_frequency_capital_words(text: str, N: int, quantifier: str) -> bool:
    """Check the count of all-capital ASCII words (2+ letters, e.g. "NASA") against ``N``.

    "around" allows a tolerance of max(round(10% of N), 1).
    """
    actual_count = len(re.findall(r"\b[A-Z]{2,}\b", text))
    tolerance = max(round(N * 0.1), 1)
    if quantifier == "at least":
        return actual_count >= N
    elif quantifier == "at most":
        return actual_count <= N
    elif quantifier == "around":
        return abs(actual_count - N) <= tolerance
    return False


# End Checker: Finish your response with this exact phrase {end phrase}.
def validate_end(text: str, end_phrase: str) -> bool:
    """True when the trimmed response ends with the trimmed ``end_phrase``."""
    return text.strip().endswith(end_phrase.strip())


# Quotation: Wrap your entire response with double quotation marks.
def validate_quotation(text: str) -> bool:
    """True when the trimmed response starts and ends with a double quote."""
    stripped_text = text.strip()
    return stripped_text.startswith('"') and stripped_text.endswith('"')


# No Commas: In your entire response, refrain from the use of any commas.
def validate_no_commas(text: str) -> bool:
    """True when the response contains no comma."""
    return "," not in text


# Maps the ``func_name`` stored in a dataset item to its verifier. Each verifier takes the
# response text as its first argument, followed by that constraint's keyword arguments.
IF_FUNCTIONS_MAP = {
    "verify_keywords": verify_keywords,
    "verify_keyword_frequency": verify_keyword_frequency,
    "validate_forbidden_words": validate_forbidden_words,
    "verify_letter_frequency": verify_letter_frequency,
    "validate_response_language": validate_response_language,
    "verify_paragraph_count": verify_paragraph_count,
    "validate_word_constraint": validate_word_constraint,
    "verify_sentence_constraint": verify_sentence_constraint,
    "validate_paragraphs": validate_paragraphs,
    "verify_postscript": verify_postscript,
    "validate_placeholders": validate_placeholders,
    "verify_bullet_points": verify_bullet_points,
    "validate_title": validate_title,
    "validate_choice": validate_choice,
    "validate_highlighted_sections": validate_highlighted_sections,
    "validate_sections": validate_sections,
    "validate_json_format": validate_json_format,
    "validate_repeat_prompt": validate_repeat_prompt,
    "validate_two_responses": validate_two_responses,
    "validate_uppercase": validate_uppercase,
    "validate_lowercase": validate_lowercase,
    "validate_frequency_capital_words": validate_frequency_capital_words,
    "validate_end": validate_end,
    "validate_quotation": validate_quotation,
    "validate_no_commas": validate_no_commas,
}


if __name__ == "__main__":
    InstructionFollowingEnv.cli()