Update instruction-following dataset loading to use allenai's RLVR-IFeval dataset

This commit is contained in:
teknium1 2025-05-14 18:39:31 -07:00
parent 881af55f9a
commit bcc38567ca

View file

@ -45,21 +45,20 @@ class InstructionFollowingEnv(BaseEnv):
def config_init(self) -> Tuple[BaseEnvConfig, List[APIServerConfig]]:
# Configuration for the Instruction Following Environment
env_config = BaseEnvConfig(
tokenizer_name="NousResearch/DeepHermes-3-Llama-3-8B-Preview", # Or other suitable tokenizer
group_size=32, # Number of rollouts per group
tokenizer_name="NousResearch/DeepHermes-3-Llama-3-8B-Preview",
group_size=16,
use_wandb=True,
rollout_server_url="http://localhost:8000", # Assuming same rollout server
total_steps=2000,
batch_size=1024, # Samples per training batch
rollout_server_url="http://localhost:8000",
total_steps=500,
batch_size=1024,
steps_per_eval=20,
max_token_length=1024 * 2, # Max token length for model generation, adjust as needed for IF
max_token_length=1024 * 16,
inference_weight=1.0,
wandb_name="instruction_following_rlvr", # Specific WandB project name
wandb_name="instruction_following_rlvr_ifeval", # Specific WandB project name
eval_handling=EvalHandlingEnum.LIMIT_TRAIN,
eval_limit_ratio=0.1,
# Add any specific dataset configurations here if needed
# dataset_name="allenai/ifeval",
# dataset_config_name="default", # Or specific config for the dataset
dataset_name="allenai/RLVR-IFeval", # Default dataset
dataset_config_name=None, # RLVR-IFeval doesn't have a specific config name, uses 'default'
)
# Server configurations can be similar to SingleToolCallingEnv or adjusted
server_configs = [
@ -113,34 +112,22 @@ class InstructionFollowingEnv(BaseEnv):
async def setup(self):
"""
Load and preprocess the dataset for instruction following.
This method expects each data item from the loaded dataset to have at least:
- 'prompt': The instruction text for the LLM (string).
- 'func_name': The string name of the verifier function from IF_FUNCTIONS_MAP.
- 'args_json': A JSON string representing the dictionary of arguments for the verifier function.
This method is specifically tailored to process 'allenai/RLVR-IFeval' dataset structure.
Each item from RLVR-IFeval is expected to have:
- 'messages': A list of dictionaries, e.g., [{'role': 'user', 'content': 'instruction...'}]
- 'ground_truth': A JSON string containing 'func_name' and arguments for the verifier.
Example item structure expected from the dataset loader (after any initial parsing):
{
"prompt": "Include the keywords 'apple' and 'banana' in your response.",
"func_name": "verify_keywords",
"args_json": "{\\"keyword_list\\": [\\"apple\\", \\"banana\\"]}", // Note: escaped for Python string
"original_constraints_for_logging": "Include keywords {apple}, {banana} in your response", // Optional
"expected_response_for_logging": "An apple and a banana are fruits." // Optional
}
If your raw dataset (e.g., from Hugging Face) has a natural language constraint string,
you need to implement a parsing step (either before this environment or at the beginning
of this setup method) to convert that string into 'func_name' and 'args_json'.
The verifier functions and IF_FUNCTIONS_MAP included in this file define the available functions
and their expected argument names.
The method will parse these to produce items for the environment with:
- 'prompt': The user's instruction string.
- 'func_name': The string name of the verifier function.
- 'args': A dictionary of arguments for that verifier function.
"""
dataset_name = getattr(self.config, "dataset_name", "allenai/ifeval") # Example dataset
dataset_config_name = getattr(self.config, "dataset_config_name", None)
dataset_name = getattr(self.config, "dataset_name", "allenai/RLVR-IFeval")
dataset_config_name = getattr(self.config, "dataset_config_name", None) # Default is None, RLVR-IFeval has no sub-config
processed_items = []
try:
# Attempt to load the dataset specified in the config
# This section assumes 'dataset_name' provides items with 'prompt', 'func_name', and 'args_json'
print(f"Attempting to load dataset: {dataset_name}, config: {dataset_config_name}")
print(f"Attempting to load dataset: {dataset_name}, config: {dataset_config_name if dataset_config_name else 'default'}")
if dataset_config_name:
full_dataset_raw = load_dataset(dataset_name, dataset_config_name, split="train", trust_remote_code=True)
else:
@ -148,51 +135,64 @@ class InstructionFollowingEnv(BaseEnv):
print(f"Successfully loaded raw dataset. Number of items: {len(full_dataset_raw)}")
for i, item in enumerate(full_dataset_raw):
prompt_text = item.get("prompt")
func_name_from_item = item.get("func_name")
args_json_from_item = item.get("args_json") # Expecting a JSON string
if not prompt_text or not func_name_from_item or args_json_from_item is None: # Check explicitly for None
print(f"Warning: Item {i} missing 'prompt', 'func_name', or 'args_json'. Skipping. Item: {item}")
# Extract prompt from 'messages' field
item_messages = item.get("messages")
if not item_messages or not isinstance(item_messages, list) or len(item_messages) == 0:
print(f"Warning: Item {i} has invalid or empty 'messages' field. Skipping. Item: {item}")
continue
if func_name_from_item not in IF_FUNCTIONS_MAP:
print(f"Warning: func_name '{func_name_from_item}' in item {i} not in IF_FUNCTIONS_MAP. Skipping. Prompt: {prompt_text[:50]}...")
# Assuming the relevant prompt is the content of the first message in the list
# (or last, if multiple user messages were possible, but IFEval is typically single user instruction)
prompt_text = item_messages[0].get("content")
if not prompt_text:
print(f"Warning: Item {i} '{item_messages[0]}' has no content. Skipping.")
continue
# Get the ground_truth JSON string
ground_truth_json_str = item.get("ground_truth")
if not ground_truth_json_str or not isinstance(ground_truth_json_str, str):
print(f"Warning: Item {i} missing or has invalid 'ground_truth' string. Skipping. Prompt: {prompt_text[:50]}...")
continue
try:
args_dict = json.loads(args_json_from_item)
if not isinstance(args_dict, dict):
# Allow empty string for args_json to represent empty dict, but json.loads('') fails
# However, json.loads('{}') is fine. Assume args_json is valid JSON if not empty.
raise ValueError("Parsed args_json is not a dictionary.")
except json.JSONDecodeError as e:
print(f"Warning: Could not parse 'args_json' for item {i}. Error: {e}. Args JSON: '{args_json_from_item}'. Prompt: {prompt_text[:50]}... Skipping.")
parsed_gt = json.loads(ground_truth_json_str)
if not isinstance(parsed_gt, dict):
raise ValueError("Parsed ground_truth is not a dictionary.")
except (json.JSONDecodeError, ValueError) as e:
print(f"Warning: Could not parse 'ground_truth' JSON for item {i}. Error: {e}. GT String: '{ground_truth_json_str}'. Prompt: {prompt_text[:50]}... Skipping.")
continue
except ValueError as e: # Catches the "not a dictionary" error
print(f"Warning: Parsed 'args_json' was not a dictionary for item {i}. Error: {e}. Args JSON: '{args_json_from_item}'. Prompt: {prompt_text[:50]}... Skipping.")
func_name_from_gt = parsed_gt.get("func_name")
if not func_name_from_gt:
print(f"Warning: Item {i} parsed 'ground_truth' has no 'func_name'. GT: {parsed_gt}. Prompt: {prompt_text[:50]}... Skipping.")
continue
if func_name_from_gt not in IF_FUNCTIONS_MAP:
print(f"Warning: func_name '{func_name_from_gt}' in item {i} not in IF_FUNCTIONS_MAP. Prompt: {prompt_text[:50]}... Skipping.")
continue
# Prepare args for the verifier function: remove func_name and keep others.
# Verifier functions will only use args they expect.
args_dict = {k: v for k, v in parsed_gt.items() if k != "func_name" and v is not None}
processed_items.append({
"prompt": prompt_text,
"func_name": func_name_from_item,
"args": args_dict, # Parsed dictionary of arguments
"original_constraints_for_logging": item.get("original_constraints_for_logging", str(item.get("constraints", ""))), # For logging
"expected_response_for_logging": item.get("expected_response_for_logging", str(item.get("response", ""))) # For logging
"func_name": func_name_from_gt,
"args": args_dict,
"original_constraints_for_logging": str(item.get("constraint", "")), # For logging, from RLVR-IFeval structure
"expected_response_for_logging": "" # RLVR-IFeval doesn't seem to have a sample good response directly
})
if not processed_items:
print("Warning: No items successfully processed from the dataset. Check dataset format/content or parsing logic if any.")
# Fallback to dummy data if processing yields nothing, to allow environment to initialize.
# This indicates a problem with the primary dataset source or its assumed structure.
raise ValueError("Dataset processing resulted in no valid items. Cannot proceed without data or a valid dummy fallback.")
print("Warning: No items successfully processed from the dataset. Check dataset format/content or parsing logic.")
raise ValueError("Dataset processing resulted in no valid items for RLVR-IFeval. Cannot proceed without data.")
full_dataset = Dataset.from_list(processed_items)
print(f"Successfully processed {len(full_dataset)} items from dataset.")
print(f"Successfully processed {len(full_dataset)} items from dataset '{dataset_name}'.")
except Exception as e:
print(f"CRITICAL: Failed to load or process primary dataset '{dataset_name}': {e}. Using a small DUMMY dataset as a fallback.")
# Fallback to a minimal dummy dataset with the expected structure for 'args' already parsed
# This block is a fallback if the primary dataset loading/processing fails catastrophically.
# For RLVR-IFeval, a failure here suggests issues with Hugging Face access, dataset integrity, or fundamental code errors.
print(f"CRITICAL: Failed to load or process primary dataset '{dataset_name}': {e}. Using a DUMMY dataset as fallback.")
dummy_data_for_fallback = [
{
"prompt": "Dummy Instruction 1: Ensure your response contains the word 'example'.",
@ -206,11 +206,11 @@ class InstructionFollowingEnv(BaseEnv):
"func_name": "validate_json_format",
"args": {},
"original_constraints_for_logging": "Output valid JSON.",
"expected_response_for_logging": "{\\\"data\\\": \\\"test\\\"}" # Corrected: Escaped for Python string
"expected_response_for_logging": "{\\\"data\\\": \\\"test\\\"}"
}
]
full_dataset = Dataset.from_list(dummy_data_for_fallback)
print(f"Initialized with DUMMY dataset of {len(full_dataset)} items.")
print(f"Initialized with DUMMY dataset of {len(full_dataset)} items due to previous errors.")
full_dataset = full_dataset.shuffle(seed=42)
@ -457,8 +457,8 @@ class InstructionFollowingEnv(BaseEnv):
return scores_container # Avoid division by zero, or if all empty
max_allowed_length = self.config.max_token_length
# Threshold can be adjusted, e.g., 50% of max_token_length
length_threshold = max_allowed_length * 0.5
# Threshold can be adjusted, e.g., 75% of max_token_length
length_threshold = max_allowed_length * 0.75
penalized_scores = []
for i, length in enumerate(token_lengths):