diff --git a/environments/README.md b/environments/README.md index 923d6b53..6d34ebb0 100644 --- a/environments/README.md +++ b/environments/README.md @@ -93,7 +93,9 @@ You are a deep thinking AI, you may use extremely long chains of thought to deep ### Instruction Following Environment (`instruction_following_algorithm_environment.py`) Environment for training models to follow natural language instructions and constraints, based on the `allenai/RLVR-IFeval` dataset and environment. -*This environment has unique dependencies: `datasets` (from Hugging Face) and `langdetect`.* +**Dependencies:** +- `datasets` (Hugging Face) +- `langdetect` **Input Format:** - Each item from the processed `allenai/RLVR-IFeval` dataset contains: @@ -120,7 +122,7 @@ You are a deep thinking AI, you may use extremely long chains of thought to deep - `dataset_config_name`: Optional name for a specific configuration or subset of the dataset. - `test_set_ratio`: Defines the proportion of the dataset reserved for testing (defaults to 5%). -- **Verifier-Based Scoring:** Utilizes a comprehensive map of verifier functions (`IF_FUNCTIONS_MAP`) to evaluate whether the model's +- **Verifier-Based Scoring:** Utilizes a comprehensive map of verifier functions (`IF_FUNCTIONS_MAP`) to evaluate whether the model's output adheres to diverse and specific constraints defined in the input instructions (e.g., keyword presence, response length, JSON format, etc.). - **Specialized Dataset Processing:** The `setup` method is specifically designed to parse the `allenai/RLVR-IFeval` dataset, extracting user instructions, the corresponding verifier function name, and its arguments. diff --git a/environments/instruction_following_algorithm_environment.py b/environments/instruction_following_algorithm_environment.py index 19aaadc7..8735f3f8 100644 --- a/environments/instruction_following_algorithm_environment.py +++ b/environments/instruction_following_algorithm_environment.py @@ -4,9 +4,9 @@ import re from typing import Dict, List, Optional, Tuple import wandb -from datasets import Dataset, load_dataset # Added Dataset for dummy data +from datasets import Dataset, load_dataset from langdetect import LangDetectException, detect -from pydantic import Field # Added import for Field +from pydantic import Field from tqdm.asyncio import tqdm_asyncio from atroposlib.envs.base import ( @@ -39,11 +39,11 @@ class IFConfig(BaseEnvConfig): class InstructionFollowingEnv(BaseEnv): - env_config_cls = IFConfig # Added env_config_cls for IFConfig + env_config_cls = IFConfig def __init__( self, - config: IFConfig, # Changed BaseEnvConfig to IFConfig + config: IFConfig, server_configs: List[APIServerConfig], slurm=True, testing=False, @@ -52,14 +52,13 @@ class InstructionFollowingEnv(BaseEnv): self.percent_correct_buffer = list() self.eval_metrics = list() self.rollouts_for_wandb = [] - # self.completion_lengths = [] # Kept from SingleToolCallingEnv, assess utility @classmethod def config_init( self, - ) -> Tuple[IFConfig, List[APIServerConfig]]: # Changed BaseEnvConfig to IFConfig + ) -> Tuple[IFConfig, List[APIServerConfig]]: # Configuration for the Instruction Following Environment - env_config = IFConfig( # Changed BaseEnvConfig to IFConfig + env_config = IFConfig( tokenizer_name="NousResearch/DeepHermes-3-Llama-3-8B-Preview", group_size=32, use_wandb=True, @@ -130,7 +129,7 @@ class InstructionFollowingEnv(BaseEnv): - 'prompt': The user's instruction string. - 'func_name': The string name of the verifier function. - 'args': A dictionary of arguments for that verifier function. - """ + """ # noqa: E501 dataset_name = getattr(self.config, "dataset_name", "allenai/RLVR-IFeval") dataset_config_name = getattr( self.config, "dataset_config_name", None @@ -231,7 +230,7 @@ class InstructionFollowingEnv(BaseEnv): "original_constraints_for_logging": str( item.get("constraint", "") ), # For logging, from RLVR-IFeval structure - "expected_response_for_logging": "", # RLVR-IFeval doesn't seem to have a sample good response directly # noqa: E501 + "expected_response_for_logging": "", } ) @@ -251,7 +250,8 @@ class InstructionFollowingEnv(BaseEnv): except Exception as e: # This block is a fallback if the primary dataset loading/processing fails catastrophically. - # For RLVR-IFeval, a failure here suggests issues with Hugging Face access, dataset integrity, or fundamental code errors. # noqa: E501 + # For RLVR-IFeval, a failure here suggests issues with Hugging Face access, + # dataset integrity, or fundamental code errors. print( f"CRITICAL: Failed to load or process primary dataset '{dataset_name}': {e}. " f"Using a DUMMY dataset as fallback." @@ -316,12 +316,14 @@ class InstructionFollowingEnv(BaseEnv): # Final check for empty train/test after split, should not happen with logic above if num_items > 0 if len(self.train) == 0 and len(self.test) > 0: print( - "Warning: Train set empty after split, test set has data. This is unusual. Swapping." + "Warning: Train set empty after split, test set has data. " + "This is unusual. Swapping." ) self.train = self.test # Fallback, though indicates issue elif len(self.test) == 0 and len(self.train) > 0: print( - "Warning: Test set empty after split, train set has data. Using full train set for test as well." + "Warning: Test set empty after split, train set has data. " + "Using full train set for test as well." ) self.test = self.train @@ -339,41 +341,40 @@ class InstructionFollowingEnv(BaseEnv): # 1. Count and tags num_think_open = len(re.findall(r"", model_response_text, re.IGNORECASE)) - num_think_close = len(re.findall(r"", model_response_text, re.IGNORECASE)) + num_think_close = len( + re.findall(r"", model_response_text, re.IGNORECASE) + ) - # 2. If counts are not exactly one for each, malformed. if not (num_think_open == 1 and num_think_close == 1): - # Optionally, add a debug print statement here if needed for logging - # print(f"DEBUG: Malformed think tags. Open: {num_think_open}, Close: {num_think_close}. Response: '{model_response_text[:200]}...'") return 0.0 # 3. Find the first occurrence of and try: think_open_match = re.search(r"", model_response_text, re.IGNORECASE) - think_close_match = re.search(r"", model_response_text, re.IGNORECASE) - + think_close_match = re.search( + r"", model_response_text, re.IGNORECASE + ) + # These should exist due to the count check, but access .start() and .end() safely idx_think_open = think_open_match.start() idx_think_close_start = think_close_match.start() idx_think_close_end = think_close_match.end() - except AttributeError: - # This case should ideally be caught by the count check, but as a fallback: - # print(f"DEBUG: Could not find start/end of think/close tags despite counts. Response: '{model_response_text[:200]}...'") + except AttributeError: return 0.0 # 4. If appears after , malformed. if idx_think_open >= idx_think_close_start: # print(f"DEBUG: tag appears at or after tag. Response: '{model_response_text[:200]}...'") return 0.0 - + # 5. Extract text_to_verify (content after the first ) text_to_verify = model_response_text[idx_think_close_end:].strip() # 6. Check if text_to_verify itself contains any further or tags. - if re.search(r"", text_to_verify, re.IGNORECASE) or \ - re.search(r"", text_to_verify, re.IGNORECASE): - # print(f"DEBUG: Found or tags in the answer part. Answer part: '{text_to_verify[:200]}...'") + if re.search(r"", text_to_verify, re.IGNORECASE) or re.search( + r"", text_to_verify, re.IGNORECASE + ): return 0.0 # If all checks pass, proceed with verification using text_to_verify @@ -397,6 +398,7 @@ class InstructionFollowingEnv(BaseEnv): ) else: from inspect import signature + sig = signature(verifier_func) valid_params = [p for p in sig.parameters if p != "text"] filtered_args = { @@ -446,7 +448,9 @@ class InstructionFollowingEnv(BaseEnv): func_name = test_item["func_name"] args_for_verifier = test_item["args"] - print(f"DEBUG: Entering rollout_and_score_eval. Prompt: {instruction_prompt_text[:200]}...") # DEBUG + print( + f"DEBUG: Entering rollout_and_score_eval. Prompt: {instruction_prompt_text[:200]}..." + ) # DEBUG messages = [{"role": "system", "content": system_prompt}] messages.append({"role": "user", "content": instruction_prompt_text}) @@ -455,7 +459,9 @@ class InstructionFollowingEnv(BaseEnv): messages, add_generation_prompt=True, tokenize=False ) - print(f"DEBUG: Calling self.server.completion in rollout_and_score_eval. Prompt: {prompt_str[:200]}...") # DEBUG + print( + f"DEBUG: Calling self.server.completion in rollout_and_score_eval. Prompt: {prompt_str[:200]}..." + ) # DEBUG completion = await self.server.completion( prompt=prompt_str, n=1, @@ -463,7 +469,7 @@ class InstructionFollowingEnv(BaseEnv): temperature=0.2, # Temperature for eval, can be 0 for deterministic split="eval", ) - print(f"DEBUG: Received completion in rollout_and_score_eval.") # DEBUG + print("DEBUG: Received completion in rollout_and_score_eval.") # DEBUG model_response_text = completion.choices[0].text score_value = await self._get_score_from_verifier( @@ -481,7 +487,7 @@ class InstructionFollowingEnv(BaseEnv): self.eval_metrics.append(("eval/percent_correct", 0.0)) return - print(f"DEBUG: Starting evaluation. Test set size: {len(self.test)}") # DEBUG + print(f"DEBUG: Starting evaluation. Test set size: {len(self.test)}") # DEBUG eval_tasks = [] for test_item_dict in self.test: # self.test contains dicts after setup eval_tasks.append(self.rollout_and_score_eval(test_item_dict)) @@ -501,7 +507,7 @@ class InstructionFollowingEnv(BaseEnv): ) -> Tuple[Optional[ScoredDataGroup], List]: # item = (prompt_messages_tuple, answer_info_dict) # answer_info_dict = {"func_name": ..., "args": ...} - print(f"DEBUG: Entering collect_trajectories. Item: {item}") # DEBUG + print(f"DEBUG: Entering collect_trajectories. Item: {str(item)}") # DEBUG prompt_messages_list = [dict(msg_fset) for msg_fset in item[0]] answer_info = item[1] @@ -509,7 +515,9 @@ class InstructionFollowingEnv(BaseEnv): prompt_messages_list, add_generation_prompt=True, tokenize=False ) - print(f"DEBUG: Calling self.server.completion in collect_trajectories. Prompt: {prompt_str[:200]}...") # DEBUG + print( + f"DEBUG: Calling self.server.completion in collect_trajectories. Prompt: {prompt_str[:200]}..." + ) # DEBUG try: completions = await self.server.completion( prompt=prompt_str, @@ -517,11 +525,15 @@ class InstructionFollowingEnv(BaseEnv): max_tokens=self.config.max_token_length, temperature=0.8, # Temperature for diverse responses during training rollouts ) - print(f"DEBUG: Received {len(completions.choices)} completions in collect_trajectories.") # DEBUG + print( + f"DEBUG: Received {len(completions.choices)} completions in collect_trajectories." + ) # DEBUG except Exception as e: - print(f"ERROR: Exception during self.server.completion in collect_trajectories: {e}") # DEBUG + print( + f"ERROR: Exception during self.server.completion in collect_trajectories: {e}" + ) # DEBUG # Depending on the desired behavior, you might want to return None or raise the exception - return None, [] + return None, [] to_score_list = [] for choice in completions.choices: @@ -534,11 +546,15 @@ class InstructionFollowingEnv(BaseEnv): if not to_score_list: return None, [] - print(f"DEBUG: Scoring {len(to_score_list)} trajectories in collect_trajectories.") # DEBUG + print( + f"DEBUG: Scoring {len(to_score_list)} trajectories in collect_trajectories." + ) # DEBUG scored_data = await self.score(to_score_list) to_backlog = [] # Backlog not currently used but part of signature - print(f"DEBUG: Exiting collect_trajectories. Scored data: {bool(scored_data)}") # DEBUG + print( + f"DEBUG: Exiting collect_trajectories. Scored data: {bool(scored_data)}" + ) # DEBUG return scored_data, to_backlog def save_checkpoint(self, step, data=None): @@ -817,7 +833,8 @@ def validate_word_constraint(text: str, N: int, quantifier: str) -> bool: def verify_sentence_constraint(text: str, N: int, quantifier: str) -> bool: # Basic sentence splitting, might need more robust NLP for complex cases sentences = re.split( - r"(?