diff --git a/environments/README.md b/environments/README.md
index 923d6b53..6d34ebb0 100644
--- a/environments/README.md
+++ b/environments/README.md
@@ -93,7 +93,9 @@ You are a deep thinking AI, you may use extremely long chains of thought to deep
### Instruction Following Environment (`instruction_following_algorithm_environment.py`)
Environment for training models to follow natural language instructions and constraints, based on the `allenai/RLVR-IFeval` dataset and environment.
-*This environment has unique dependencies: `datasets` (from Hugging Face) and `langdetect`.*
+**Dependencies:**
+- `datasets` (Hugging Face)
+- `langdetect`
**Input Format:**
- Each item from the processed `allenai/RLVR-IFeval` dataset contains:
@@ -120,7 +122,7 @@ You are a deep thinking AI, you may use extremely long chains of thought to deep
- `dataset_config_name`: Optional name for a specific configuration or subset of the dataset.
- `test_set_ratio`: Defines the proportion of the dataset reserved for testing (defaults to 5%).
-- **Verifier-Based Scoring:** Utilizes a comprehensive map of verifier functions (`IF_FUNCTIONS_MAP`) to evaluate whether the model's
+- **Verifier-Based Scoring:** Utilizes a comprehensive map of verifier functions (`IF_FUNCTIONS_MAP`) to evaluate whether the model's
output adheres to diverse and specific constraints defined in the input instructions (e.g., keyword presence, response length, JSON format, etc.).
- **Specialized Dataset Processing:** The `setup` method is specifically designed to parse the `allenai/RLVR-IFeval` dataset, extracting user instructions, the corresponding verifier function name, and its arguments.
diff --git a/environments/instruction_following_algorithm_environment.py b/environments/instruction_following_algorithm_environment.py
index 19aaadc7..8735f3f8 100644
--- a/environments/instruction_following_algorithm_environment.py
+++ b/environments/instruction_following_algorithm_environment.py
@@ -4,9 +4,9 @@ import re
from typing import Dict, List, Optional, Tuple
import wandb
-from datasets import Dataset, load_dataset # Added Dataset for dummy data
+from datasets import Dataset, load_dataset
from langdetect import LangDetectException, detect
-from pydantic import Field # Added import for Field
+from pydantic import Field
from tqdm.asyncio import tqdm_asyncio
from atroposlib.envs.base import (
@@ -39,11 +39,11 @@ class IFConfig(BaseEnvConfig):
class InstructionFollowingEnv(BaseEnv):
- env_config_cls = IFConfig # Added env_config_cls for IFConfig
+ env_config_cls = IFConfig
def __init__(
self,
- config: IFConfig, # Changed BaseEnvConfig to IFConfig
+ config: IFConfig,
server_configs: List[APIServerConfig],
slurm=True,
testing=False,
@@ -52,14 +52,13 @@ class InstructionFollowingEnv(BaseEnv):
self.percent_correct_buffer = list()
self.eval_metrics = list()
self.rollouts_for_wandb = []
- # self.completion_lengths = [] # Kept from SingleToolCallingEnv, assess utility
@classmethod
def config_init(
self,
- ) -> Tuple[IFConfig, List[APIServerConfig]]: # Changed BaseEnvConfig to IFConfig
+ ) -> Tuple[IFConfig, List[APIServerConfig]]:
# Configuration for the Instruction Following Environment
- env_config = IFConfig( # Changed BaseEnvConfig to IFConfig
+ env_config = IFConfig(
tokenizer_name="NousResearch/DeepHermes-3-Llama-3-8B-Preview",
group_size=32,
use_wandb=True,
@@ -130,7 +129,7 @@ class InstructionFollowingEnv(BaseEnv):
- 'prompt': The user's instruction string.
- 'func_name': The string name of the verifier function.
- 'args': A dictionary of arguments for that verifier function.
- """
+ """ # noqa: E501
dataset_name = getattr(self.config, "dataset_name", "allenai/RLVR-IFeval")
dataset_config_name = getattr(
self.config, "dataset_config_name", None
@@ -231,7 +230,7 @@ class InstructionFollowingEnv(BaseEnv):
"original_constraints_for_logging": str(
item.get("constraint", "")
), # For logging, from RLVR-IFeval structure
- "expected_response_for_logging": "", # RLVR-IFeval doesn't seem to have a sample good response directly # noqa: E501
+ "expected_response_for_logging": "",
}
)
@@ -251,7 +250,8 @@ class InstructionFollowingEnv(BaseEnv):
except Exception as e:
# This block is a fallback if the primary dataset loading/processing fails catastrophically.
- # For RLVR-IFeval, a failure here suggests issues with Hugging Face access, dataset integrity, or fundamental code errors. # noqa: E501
+ # For RLVR-IFeval, a failure here suggests issues with Hugging Face access,
+ # dataset integrity, or fundamental code errors.
print(
f"CRITICAL: Failed to load or process primary dataset '{dataset_name}': {e}. "
f"Using a DUMMY dataset as fallback."
@@ -316,12 +316,14 @@ class InstructionFollowingEnv(BaseEnv):
# Final check for empty train/test after split, should not happen with logic above if num_items > 0
if len(self.train) == 0 and len(self.test) > 0:
print(
- "Warning: Train set empty after split, test set has data. This is unusual. Swapping."
+ "Warning: Train set empty after split, test set has data. "
+ "This is unusual. Swapping."
)
self.train = self.test # Fallback, though indicates issue
elif len(self.test) == 0 and len(self.train) > 0:
print(
- "Warning: Test set empty after split, train set has data. Using full train set for test as well."
+ "Warning: Test set empty after split, train set has data. "
+ "Using full train set for test as well."
)
self.test = self.train
@@ -339,41 +341,40 @@ class InstructionFollowingEnv(BaseEnv):
# 1. Count <think> and </think> tags
num_think_open = len(re.findall(r"<think>", model_response_text, re.IGNORECASE))
- num_think_close = len(re.findall(r"</think>", model_response_text, re.IGNORECASE))
+ num_think_close = len(
+ re.findall(r"</think>", model_response_text, re.IGNORECASE)
+ )
- # 2. If counts are not exactly one for each, malformed.
if not (num_think_open == 1 and num_think_close == 1):
- # Optionally, add a debug print statement here if needed for logging
- # print(f"DEBUG: Malformed think tags. Open: {num_think_open}, Close: {num_think_close}. Response: '{model_response_text[:200]}...'")
return 0.0
# 3. Find the first occurrence of <think> and </think>
try:
think_open_match = re.search(r"<think>", model_response_text, re.IGNORECASE)
- think_close_match = re.search(r"</think>", model_response_text, re.IGNORECASE)
-
+ think_close_match = re.search(
+ r"</think>", model_response_text, re.IGNORECASE
+ )
+
# These should exist due to the count check, but access .start() and .end() safely
idx_think_open = think_open_match.start()
idx_think_close_start = think_close_match.start()
idx_think_close_end = think_close_match.end()
- except AttributeError:
- # This case should ideally be caught by the count check, but as a fallback:
- # print(f"DEBUG: Could not find start/end of think/close tags despite counts. Response: '{model_response_text[:200]}...'")
+ except AttributeError:
return 0.0
# 4. If <think> appears after </think>, malformed.
if idx_think_open >= idx_think_close_start:
# print(f"DEBUG: <think> tag appears at or after </think> tag. Response: '{model_response_text[:200]}...'")
return 0.0
-
+
# 5. Extract text_to_verify (content after the first </think>)
text_to_verify = model_response_text[idx_think_close_end:].strip()
# 6. Check if text_to_verify itself contains any further <think> or </think> tags.
- if re.search(r"<think>", text_to_verify, re.IGNORECASE) or \
- re.search(r"</think>", text_to_verify, re.IGNORECASE):
- # print(f"DEBUG: Found <think> or </think> tags in the answer part. Answer part: '{text_to_verify[:200]}...'")
+ if re.search(r"<think>", text_to_verify, re.IGNORECASE) or re.search(
+ r"</think>", text_to_verify, re.IGNORECASE
+ ):
return 0.0
# If all checks pass, proceed with verification using text_to_verify
@@ -397,6 +398,7 @@ class InstructionFollowingEnv(BaseEnv):
)
else:
from inspect import signature
+
sig = signature(verifier_func)
valid_params = [p for p in sig.parameters if p != "text"]
filtered_args = {
@@ -446,7 +448,9 @@ class InstructionFollowingEnv(BaseEnv):
func_name = test_item["func_name"]
args_for_verifier = test_item["args"]
- print(f"DEBUG: Entering rollout_and_score_eval. Prompt: {instruction_prompt_text[:200]}...") # DEBUG
+ print(
+ f"DEBUG: Entering rollout_and_score_eval. Prompt: {instruction_prompt_text[:200]}..."
+ ) # DEBUG
messages = [{"role": "system", "content": system_prompt}]
messages.append({"role": "user", "content": instruction_prompt_text})
@@ -455,7 +459,9 @@ class InstructionFollowingEnv(BaseEnv):
messages, add_generation_prompt=True, tokenize=False
)
- print(f"DEBUG: Calling self.server.completion in rollout_and_score_eval. Prompt: {prompt_str[:200]}...") # DEBUG
+ print(
+ f"DEBUG: Calling self.server.completion in rollout_and_score_eval. Prompt: {prompt_str[:200]}..."
+ ) # DEBUG
completion = await self.server.completion(
prompt=prompt_str,
n=1,
@@ -463,7 +469,7 @@ class InstructionFollowingEnv(BaseEnv):
temperature=0.2, # Temperature for eval, can be 0 for deterministic
split="eval",
)
- print(f"DEBUG: Received completion in rollout_and_score_eval.") # DEBUG
+ print("DEBUG: Received completion in rollout_and_score_eval.") # DEBUG
model_response_text = completion.choices[0].text
score_value = await self._get_score_from_verifier(
@@ -481,7 +487,7 @@ class InstructionFollowingEnv(BaseEnv):
self.eval_metrics.append(("eval/percent_correct", 0.0))
return
- print(f"DEBUG: Starting evaluation. Test set size: {len(self.test)}") # DEBUG
+ print(f"DEBUG: Starting evaluation. Test set size: {len(self.test)}") # DEBUG
eval_tasks = []
for test_item_dict in self.test: # self.test contains dicts after setup
eval_tasks.append(self.rollout_and_score_eval(test_item_dict))
@@ -501,7 +507,7 @@ class InstructionFollowingEnv(BaseEnv):
) -> Tuple[Optional[ScoredDataGroup], List]:
# item = (prompt_messages_tuple, answer_info_dict)
# answer_info_dict = {"func_name": ..., "args": ...}
- print(f"DEBUG: Entering collect_trajectories. Item: {item}") # DEBUG
+ print(f"DEBUG: Entering collect_trajectories. Item: {str(item)}") # DEBUG
prompt_messages_list = [dict(msg_fset) for msg_fset in item[0]]
answer_info = item[1]
@@ -509,7 +515,9 @@ class InstructionFollowingEnv(BaseEnv):
prompt_messages_list, add_generation_prompt=True, tokenize=False
)
- print(f"DEBUG: Calling self.server.completion in collect_trajectories. Prompt: {prompt_str[:200]}...") # DEBUG
+ print(
+ f"DEBUG: Calling self.server.completion in collect_trajectories. Prompt: {prompt_str[:200]}..."
+ ) # DEBUG
try:
completions = await self.server.completion(
prompt=prompt_str,
@@ -517,11 +525,15 @@ class InstructionFollowingEnv(BaseEnv):
max_tokens=self.config.max_token_length,
temperature=0.8, # Temperature for diverse responses during training rollouts
)
- print(f"DEBUG: Received {len(completions.choices)} completions in collect_trajectories.") # DEBUG
+ print(
+ f"DEBUG: Received {len(completions.choices)} completions in collect_trajectories."
+ ) # DEBUG
except Exception as e:
- print(f"ERROR: Exception during self.server.completion in collect_trajectories: {e}") # DEBUG
+ print(
+ f"ERROR: Exception during self.server.completion in collect_trajectories: {e}"
+ ) # DEBUG
# Depending on the desired behavior, you might want to return None or raise the exception
- return None, []
+ return None, []
to_score_list = []
for choice in completions.choices:
@@ -534,11 +546,15 @@ class InstructionFollowingEnv(BaseEnv):
if not to_score_list:
return None, []
- print(f"DEBUG: Scoring {len(to_score_list)} trajectories in collect_trajectories.") # DEBUG
+ print(
+ f"DEBUG: Scoring {len(to_score_list)} trajectories in collect_trajectories."
+ ) # DEBUG
scored_data = await self.score(to_score_list)
to_backlog = [] # Backlog not currently used but part of signature
- print(f"DEBUG: Exiting collect_trajectories. Scored data: {bool(scored_data)}") # DEBUG
+ print(
+ f"DEBUG: Exiting collect_trajectories. Scored data: {bool(scored_data)}"
+ ) # DEBUG
return scored_data, to_backlog
def save_checkpoint(self, step, data=None):
@@ -817,7 +833,8 @@ def validate_word_constraint(text: str, N: int, quantifier: str) -> bool:
def verify_sentence_constraint(text: str, N: int, quantifier: str) -> bool:
# Basic sentence splitting, might need more robust NLP for complex cases
sentences = re.split(
- r"(?