mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-19 12:57:58 +00:00
update some dataset stuff to use allenai's
This commit is contained in:
parent
881af55f9a
commit
bcc38567ca
1 changed files with 66 additions and 66 deletions
|
|
@ -45,21 +45,20 @@ class InstructionFollowingEnv(BaseEnv):
|
|||
def config_init(self) -> Tuple[BaseEnvConfig, List[APIServerConfig]]:
|
||||
# Configuration for the Instruction Following Environment
|
||||
env_config = BaseEnvConfig(
|
||||
tokenizer_name="NousResearch/DeepHermes-3-Llama-3-8B-Preview", # Or other suitable tokenizer
|
||||
group_size=32, # Number of rollouts per group
|
||||
tokenizer_name="NousResearch/DeepHermes-3-Llama-3-8B-Preview",
|
||||
group_size=16,
|
||||
use_wandb=True,
|
||||
rollout_server_url="http://localhost:8000", # Assuming same rollout server
|
||||
total_steps=2000,
|
||||
batch_size=1024, # Samples per training batch
|
||||
rollout_server_url="http://localhost:8000",
|
||||
total_steps=500,
|
||||
batch_size=1024,
|
||||
steps_per_eval=20,
|
||||
max_token_length=1024 * 2, # Max token length for model generation, adjust as needed for IF
|
||||
max_token_length=1024 * 16,
|
||||
inference_weight=1.0,
|
||||
wandb_name="instruction_following_rlvr", # Specific WandB project name
|
||||
wandb_name="instruction_following_rlvr_ifeval", # Specific WandB project name
|
||||
eval_handling=EvalHandlingEnum.LIMIT_TRAIN,
|
||||
eval_limit_ratio=0.1,
|
||||
# Add any specific dataset configurations here if needed
|
||||
# dataset_name="allenai/ifeval",
|
||||
# dataset_config_name="default", # Or specific config for the dataset
|
||||
dataset_name="allenai/RLVR-IFeval", # Default dataset
|
||||
dataset_config_name=None, # RLVR-IFeval doesn't have a specific config name, uses 'default'
|
||||
)
|
||||
# Server configurations can be similar to SingleToolCallingEnv or adjusted
|
||||
server_configs = [
|
||||
|
|
@ -113,34 +112,22 @@ class InstructionFollowingEnv(BaseEnv):
|
|||
async def setup(self):
|
||||
"""
|
||||
Load and preprocess the dataset for instruction following.
|
||||
This method expects each data item from the loaded dataset to have at least:
|
||||
- 'prompt': The instruction text for the LLM (string).
|
||||
- 'func_name': The string name of the verifier function from IF_FUNCTIONS_MAP.
|
||||
- 'args_json': A JSON string representing the dictionary of arguments for the verifier function.
|
||||
This method is specifically tailored to process 'allenai/RLVR-IFeval' dataset structure.
|
||||
Each item from RLVR-IFeval is expected to have:
|
||||
- 'messages': A list of dictionaries, e.g., [{'role': 'user', 'content': 'instruction...'}]
|
||||
- 'ground_truth': A JSON string containing 'func_name' and arguments for the verifier.
|
||||
|
||||
Example item structure expected from the dataset loader (after any initial parsing):
|
||||
{
|
||||
"prompt": "Include the keywords 'apple' and 'banana' in your response.",
|
||||
"func_name": "verify_keywords",
|
||||
"args_json": "{\\"keyword_list\\": [\\"apple\\", \\"banana\\"]}", // Note: escaped for Python string
|
||||
"original_constraints_for_logging": "Include keywords {apple}, {banana} in your response", // Optional
|
||||
"expected_response_for_logging": "An apple and a banana are fruits." // Optional
|
||||
}
|
||||
|
||||
If your raw dataset (e.g., from Hugging Face) has a natural language constraint string,
|
||||
you need to implement a parsing step (either before this environment or at the beginning
|
||||
of this setup method) to convert that string into 'func_name' and 'args_json'.
|
||||
The verifier functions and IF_FUNCTIONS_MAP included in this file define the available functions
|
||||
and their expected argument names.
|
||||
The method will parse these to produce items for the environment with:
|
||||
- 'prompt': The user's instruction string.
|
||||
- 'func_name': The string name of the verifier function.
|
||||
- 'args': A dictionary of arguments for that verifier function.
|
||||
"""
|
||||
dataset_name = getattr(self.config, "dataset_name", "allenai/ifeval") # Example dataset
|
||||
dataset_config_name = getattr(self.config, "dataset_config_name", None)
|
||||
dataset_name = getattr(self.config, "dataset_name", "allenai/RLVR-IFeval")
|
||||
dataset_config_name = getattr(self.config, "dataset_config_name", None) # Default is None, RLVR-IFeval has no sub-config
|
||||
|
||||
processed_items = []
|
||||
try:
|
||||
# Attempt to load the dataset specified in the config
|
||||
# This section assumes 'dataset_name' provides items with 'prompt', 'func_name', and 'args_json'
|
||||
print(f"Attempting to load dataset: {dataset_name}, config: {dataset_config_name}")
|
||||
print(f"Attempting to load dataset: {dataset_name}, config: {dataset_config_name if dataset_config_name else 'default'}")
|
||||
if dataset_config_name:
|
||||
full_dataset_raw = load_dataset(dataset_name, dataset_config_name, split="train", trust_remote_code=True)
|
||||
else:
|
||||
|
|
@ -148,51 +135,64 @@ class InstructionFollowingEnv(BaseEnv):
|
|||
print(f"Successfully loaded raw dataset. Number of items: {len(full_dataset_raw)}")
|
||||
|
||||
for i, item in enumerate(full_dataset_raw):
|
||||
prompt_text = item.get("prompt")
|
||||
func_name_from_item = item.get("func_name")
|
||||
args_json_from_item = item.get("args_json") # Expecting a JSON string
|
||||
|
||||
if not prompt_text or not func_name_from_item or args_json_from_item is None: # Check explicitly for None
|
||||
print(f"Warning: Item {i} missing 'prompt', 'func_name', or 'args_json'. Skipping. Item: {item}")
|
||||
# Extract prompt from 'messages' field
|
||||
item_messages = item.get("messages")
|
||||
if not item_messages or not isinstance(item_messages, list) or len(item_messages) == 0:
|
||||
print(f"Warning: Item {i} has invalid or empty 'messages' field. Skipping. Item: {item}")
|
||||
continue
|
||||
|
||||
if func_name_from_item not in IF_FUNCTIONS_MAP:
|
||||
print(f"Warning: func_name '{func_name_from_item}' in item {i} not in IF_FUNCTIONS_MAP. Skipping. Prompt: {prompt_text[:50]}...")
|
||||
# Assuming the relevant prompt is the content of the first message in the list
|
||||
# (or last, if multiple user messages were possible, but IFEval is typically single user instruction)
|
||||
prompt_text = item_messages[0].get("content")
|
||||
if not prompt_text:
|
||||
print(f"Warning: Item {i} '{item_messages[0]}' has no content. Skipping.")
|
||||
continue
|
||||
|
||||
# Get the ground_truth JSON string
|
||||
ground_truth_json_str = item.get("ground_truth")
|
||||
if not ground_truth_json_str or not isinstance(ground_truth_json_str, str):
|
||||
print(f"Warning: Item {i} missing or has invalid 'ground_truth' string. Skipping. Prompt: {prompt_text[:50]}...")
|
||||
continue
|
||||
|
||||
try:
|
||||
args_dict = json.loads(args_json_from_item)
|
||||
if not isinstance(args_dict, dict):
|
||||
# Allow empty string for args_json to represent empty dict, but json.loads('') fails
|
||||
# However, json.loads('{}') is fine. Assume args_json is valid JSON if not empty.
|
||||
raise ValueError("Parsed args_json is not a dictionary.")
|
||||
except json.JSONDecodeError as e:
|
||||
print(f"Warning: Could not parse 'args_json' for item {i}. Error: {e}. Args JSON: '{args_json_from_item}'. Prompt: {prompt_text[:50]}... Skipping.")
|
||||
parsed_gt = json.loads(ground_truth_json_str)
|
||||
if not isinstance(parsed_gt, dict):
|
||||
raise ValueError("Parsed ground_truth is not a dictionary.")
|
||||
except (json.JSONDecodeError, ValueError) as e:
|
||||
print(f"Warning: Could not parse 'ground_truth' JSON for item {i}. Error: {e}. GT String: '{ground_truth_json_str}'. Prompt: {prompt_text[:50]}... Skipping.")
|
||||
continue
|
||||
except ValueError as e: # Catches the "not a dictionary" error
|
||||
print(f"Warning: Parsed 'args_json' was not a dictionary for item {i}. Error: {e}. Args JSON: '{args_json_from_item}'. Prompt: {prompt_text[:50]}... Skipping.")
|
||||
|
||||
func_name_from_gt = parsed_gt.get("func_name")
|
||||
if not func_name_from_gt:
|
||||
print(f"Warning: Item {i} parsed 'ground_truth' has no 'func_name'. GT: {parsed_gt}. Prompt: {prompt_text[:50]}... Skipping.")
|
||||
continue
|
||||
|
||||
if func_name_from_gt not in IF_FUNCTIONS_MAP:
|
||||
print(f"Warning: func_name '{func_name_from_gt}' in item {i} not in IF_FUNCTIONS_MAP. Prompt: {prompt_text[:50]}... Skipping.")
|
||||
continue
|
||||
|
||||
# Prepare args for the verifier function: remove func_name and keep others.
|
||||
# Verifier functions will only use args they expect.
|
||||
args_dict = {k: v for k, v in parsed_gt.items() if k != "func_name" and v is not None}
|
||||
|
||||
processed_items.append({
|
||||
"prompt": prompt_text,
|
||||
"func_name": func_name_from_item,
|
||||
"args": args_dict, # Parsed dictionary of arguments
|
||||
"original_constraints_for_logging": item.get("original_constraints_for_logging", str(item.get("constraints", ""))), # For logging
|
||||
"expected_response_for_logging": item.get("expected_response_for_logging", str(item.get("response", ""))) # For logging
|
||||
"func_name": func_name_from_gt,
|
||||
"args": args_dict,
|
||||
"original_constraints_for_logging": str(item.get("constraint", "")), # For logging, from RLVR-IFeval structure
|
||||
"expected_response_for_logging": "" # RLVR-IFeval doesn't seem to have a sample good response directly
|
||||
})
|
||||
|
||||
if not processed_items:
|
||||
print("Warning: No items successfully processed from the dataset. Check dataset format/content or parsing logic if any.")
|
||||
# Fallback to dummy data if processing yields nothing, to allow environment to initialize.
|
||||
# This indicates a problem with the primary dataset source or its assumed structure.
|
||||
raise ValueError("Dataset processing resulted in no valid items. Cannot proceed without data or a valid dummy fallback.")
|
||||
print("Warning: No items successfully processed from the dataset. Check dataset format/content or parsing logic.")
|
||||
raise ValueError("Dataset processing resulted in no valid items for RLVR-IFeval. Cannot proceed without data.")
|
||||
|
||||
full_dataset = Dataset.from_list(processed_items)
|
||||
print(f"Successfully processed {len(full_dataset)} items from dataset.")
|
||||
print(f"Successfully processed {len(full_dataset)} items from dataset '{dataset_name}'.")
|
||||
|
||||
except Exception as e:
|
||||
print(f"CRITICAL: Failed to load or process primary dataset '{dataset_name}': {e}. Using a small DUMMY dataset as a fallback.")
|
||||
# Fallback to a minimal dummy dataset with the expected structure for 'args' already parsed
|
||||
# This block is a fallback if the primary dataset loading/processing fails catastrophically.
|
||||
# For RLVR-IFeval, a failure here suggests issues with Hugging Face access, dataset integrity, or fundamental code errors.
|
||||
print(f"CRITICAL: Failed to load or process primary dataset '{dataset_name}': {e}. Using a DUMMY dataset as fallback.")
|
||||
dummy_data_for_fallback = [
|
||||
{
|
||||
"prompt": "Dummy Instruction 1: Ensure your response contains the word 'example'.",
|
||||
|
|
@ -206,11 +206,11 @@ class InstructionFollowingEnv(BaseEnv):
|
|||
"func_name": "validate_json_format",
|
||||
"args": {},
|
||||
"original_constraints_for_logging": "Output valid JSON.",
|
||||
"expected_response_for_logging": "{\\\"data\\\": \\\"test\\\"}" # Corrected: Escaped for Python string
|
||||
"expected_response_for_logging": "{\\\"data\\\": \\\"test\\\"}"
|
||||
}
|
||||
]
|
||||
full_dataset = Dataset.from_list(dummy_data_for_fallback)
|
||||
print(f"Initialized with DUMMY dataset of {len(full_dataset)} items.")
|
||||
print(f"Initialized with DUMMY dataset of {len(full_dataset)} items due to previous errors.")
|
||||
|
||||
full_dataset = full_dataset.shuffle(seed=42)
|
||||
|
||||
|
|
@ -457,8 +457,8 @@ class InstructionFollowingEnv(BaseEnv):
|
|||
return scores_container # Avoid division by zero, or if all empty
|
||||
|
||||
max_allowed_length = self.config.max_token_length
|
||||
# Threshold can be adjusted, e.g., 50% of max_token_length
|
||||
length_threshold = max_allowed_length * 0.5
|
||||
# Threshold can be adjusted, e.g., 75% of max_token_length
|
||||
length_threshold = max_allowed_length * 0.75
|
||||
|
||||
penalized_scores = []
|
||||
for i, length in enumerate(token_lengths):
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue