some cleanup for final merge

This commit is contained in:
teknium1 2025-05-16 19:24:50 -07:00
parent daa6f0ff18
commit 287bbcd356
2 changed files with 58 additions and 39 deletions

View file

@ -4,9 +4,9 @@ import re
from typing import Dict, List, Optional, Tuple
import wandb
from datasets import Dataset, load_dataset # Added Dataset for dummy data
from datasets import Dataset, load_dataset
from langdetect import LangDetectException, detect
from pydantic import Field # Added import for Field
from pydantic import Field
from tqdm.asyncio import tqdm_asyncio
from atroposlib.envs.base import (
@ -39,11 +39,11 @@ class IFConfig(BaseEnvConfig):
class InstructionFollowingEnv(BaseEnv):
env_config_cls = IFConfig # Added env_config_cls for IFConfig
env_config_cls = IFConfig
def __init__(
self,
config: IFConfig, # Changed BaseEnvConfig to IFConfig
config: IFConfig,
server_configs: List[APIServerConfig],
slurm=True,
testing=False,
@ -52,14 +52,13 @@ class InstructionFollowingEnv(BaseEnv):
self.percent_correct_buffer = list()
self.eval_metrics = list()
self.rollouts_for_wandb = []
# self.completion_lengths = [] # Kept from SingleToolCallingEnv, assess utility
@classmethod
def config_init(
self,
) -> Tuple[IFConfig, List[APIServerConfig]]: # Changed BaseEnvConfig to IFConfig
) -> Tuple[IFConfig, List[APIServerConfig]]:
# Configuration for the Instruction Following Environment
env_config = IFConfig( # Changed BaseEnvConfig to IFConfig
env_config = IFConfig(
tokenizer_name="NousResearch/DeepHermes-3-Llama-3-8B-Preview",
group_size=32,
use_wandb=True,
@ -130,7 +129,7 @@ class InstructionFollowingEnv(BaseEnv):
- 'prompt': The user's instruction string.
- 'func_name': The string name of the verifier function.
- 'args': A dictionary of arguments for that verifier function.
"""
""" # noqa: E501
dataset_name = getattr(self.config, "dataset_name", "allenai/RLVR-IFeval")
dataset_config_name = getattr(
self.config, "dataset_config_name", None
@ -231,7 +230,7 @@ class InstructionFollowingEnv(BaseEnv):
"original_constraints_for_logging": str(
item.get("constraint", "")
), # For logging, from RLVR-IFeval structure
"expected_response_for_logging": "", # RLVR-IFeval doesn't seem to have a sample good response directly # noqa: E501
"expected_response_for_logging": "",
}
)
@ -251,7 +250,8 @@ class InstructionFollowingEnv(BaseEnv):
except Exception as e:
# This block is a fallback if the primary dataset loading/processing fails catastrophically.
# For RLVR-IFeval, a failure here suggests issues with Hugging Face access, dataset integrity, or fundamental code errors. # noqa: E501
# For RLVR-IFeval, a failure here suggests issues with Hugging Face access,
# dataset integrity, or fundamental code errors.
print(
f"CRITICAL: Failed to load or process primary dataset '{dataset_name}': {e}. "
f"Using a DUMMY dataset as fallback."
@ -316,12 +316,14 @@ class InstructionFollowingEnv(BaseEnv):
# Final check for empty train/test after split, should not happen with logic above if num_items > 0
if len(self.train) == 0 and len(self.test) > 0:
print(
"Warning: Train set empty after split, test set has data. This is unusual. Swapping."
"Warning: Train set empty after split, test set has data. "
"This is unusual. Swapping."
)
self.train = self.test # Fallback, though indicates issue
elif len(self.test) == 0 and len(self.train) > 0:
print(
"Warning: Test set empty after split, train set has data. Using full train set for test as well."
"Warning: Test set empty after split, train set has data. "
"Using full train set for test as well."
)
self.test = self.train
@ -339,41 +341,40 @@ class InstructionFollowingEnv(BaseEnv):
# 1. Count <think> and </think> tags
num_think_open = len(re.findall(r"<think>", model_response_text, re.IGNORECASE))
num_think_close = len(re.findall(r"</think>", model_response_text, re.IGNORECASE))
num_think_close = len(
re.findall(r"</think>", model_response_text, re.IGNORECASE)
)
# 2. If counts are not exactly one for each, malformed.
if not (num_think_open == 1 and num_think_close == 1):
# Optionally, add a debug print statement here if needed for logging
# print(f"DEBUG: Malformed think tags. Open: {num_think_open}, Close: {num_think_close}. Response: '{model_response_text[:200]}...'")
return 0.0
# 3. Find the first occurrence of <think> and </think>
try:
think_open_match = re.search(r"<think>", model_response_text, re.IGNORECASE)
think_close_match = re.search(r"</think>", model_response_text, re.IGNORECASE)
think_close_match = re.search(
r"</think>", model_response_text, re.IGNORECASE
)
# These should exist due to the count check, but access .start() and .end() safely
idx_think_open = think_open_match.start()
idx_think_close_start = think_close_match.start()
idx_think_close_end = think_close_match.end()
except AttributeError:
# This case should ideally be caught by the count check, but as a fallback:
# print(f"DEBUG: Could not find start/end of think/close tags despite counts. Response: '{model_response_text[:200]}...'")
except AttributeError:
return 0.0
# 4. If <think> appears after </think>, malformed.
if idx_think_open >= idx_think_close_start:
# print(f"DEBUG: <think> tag appears at or after </think> tag. Response: '{model_response_text[:200]}...'")
return 0.0
# 5. Extract text_to_verify (content after the first </think>)
text_to_verify = model_response_text[idx_think_close_end:].strip()
# 6. Check if text_to_verify itself contains any further <think> or </think> tags.
if re.search(r"<think>", text_to_verify, re.IGNORECASE) or \
re.search(r"</think>", text_to_verify, re.IGNORECASE):
# print(f"DEBUG: Found <think> or </think> tags in the answer part. Answer part: '{text_to_verify[:200]}...'")
if re.search(r"<think>", text_to_verify, re.IGNORECASE) or re.search(
r"</think>", text_to_verify, re.IGNORECASE
):
return 0.0
# If all checks pass, proceed with verification using text_to_verify
@ -397,6 +398,7 @@ class InstructionFollowingEnv(BaseEnv):
)
else:
from inspect import signature
sig = signature(verifier_func)
valid_params = [p for p in sig.parameters if p != "text"]
filtered_args = {
@ -446,7 +448,9 @@ class InstructionFollowingEnv(BaseEnv):
func_name = test_item["func_name"]
args_for_verifier = test_item["args"]
print(f"DEBUG: Entering rollout_and_score_eval. Prompt: {instruction_prompt_text[:200]}...") # DEBUG
print(
f"DEBUG: Entering rollout_and_score_eval. Prompt: {instruction_prompt_text[:200]}..."
) # DEBUG
messages = [{"role": "system", "content": system_prompt}]
messages.append({"role": "user", "content": instruction_prompt_text})
@ -455,7 +459,9 @@ class InstructionFollowingEnv(BaseEnv):
messages, add_generation_prompt=True, tokenize=False
)
print(f"DEBUG: Calling self.server.completion in rollout_and_score_eval. Prompt: {prompt_str[:200]}...") # DEBUG
print(
f"DEBUG: Calling self.server.completion in rollout_and_score_eval. Prompt: {prompt_str[:200]}..."
) # DEBUG
completion = await self.server.completion(
prompt=prompt_str,
n=1,
@ -463,7 +469,7 @@ class InstructionFollowingEnv(BaseEnv):
temperature=0.2, # Temperature for eval, can be 0 for deterministic
split="eval",
)
print(f"DEBUG: Received completion in rollout_and_score_eval.") # DEBUG
print("DEBUG: Received completion in rollout_and_score_eval.") # DEBUG
model_response_text = completion.choices[0].text
score_value = await self._get_score_from_verifier(
@ -481,7 +487,7 @@ class InstructionFollowingEnv(BaseEnv):
self.eval_metrics.append(("eval/percent_correct", 0.0))
return
print(f"DEBUG: Starting evaluation. Test set size: {len(self.test)}") # DEBUG
print(f"DEBUG: Starting evaluation. Test set size: {len(self.test)}") # DEBUG
eval_tasks = []
for test_item_dict in self.test: # self.test contains dicts after setup
eval_tasks.append(self.rollout_and_score_eval(test_item_dict))
@ -501,7 +507,7 @@ class InstructionFollowingEnv(BaseEnv):
) -> Tuple[Optional[ScoredDataGroup], List]:
# item = (prompt_messages_tuple, answer_info_dict)
# answer_info_dict = {"func_name": ..., "args": ...}
print(f"DEBUG: Entering collect_trajectories. Item: {item}") # DEBUG
print(f"DEBUG: Entering collect_trajectories. Item: {str(item)}") # DEBUG
prompt_messages_list = [dict(msg_fset) for msg_fset in item[0]]
answer_info = item[1]
@ -509,7 +515,9 @@ class InstructionFollowingEnv(BaseEnv):
prompt_messages_list, add_generation_prompt=True, tokenize=False
)
print(f"DEBUG: Calling self.server.completion in collect_trajectories. Prompt: {prompt_str[:200]}...") # DEBUG
print(
f"DEBUG: Calling self.server.completion in collect_trajectories. Prompt: {prompt_str[:200]}..."
) # DEBUG
try:
completions = await self.server.completion(
prompt=prompt_str,
@ -517,11 +525,15 @@ class InstructionFollowingEnv(BaseEnv):
max_tokens=self.config.max_token_length,
temperature=0.8, # Temperature for diverse responses during training rollouts
)
print(f"DEBUG: Received {len(completions.choices)} completions in collect_trajectories.") # DEBUG
print(
f"DEBUG: Received {len(completions.choices)} completions in collect_trajectories."
) # DEBUG
except Exception as e:
print(f"ERROR: Exception during self.server.completion in collect_trajectories: {e}") # DEBUG
print(
f"ERROR: Exception during self.server.completion in collect_trajectories: {e}"
) # DEBUG
# Depending on the desired behavior, you might want to return None or raise the exception
return None, []
return None, []
to_score_list = []
for choice in completions.choices:
@ -534,11 +546,15 @@ class InstructionFollowingEnv(BaseEnv):
if not to_score_list:
return None, []
print(f"DEBUG: Scoring {len(to_score_list)} trajectories in collect_trajectories.") # DEBUG
print(
f"DEBUG: Scoring {len(to_score_list)} trajectories in collect_trajectories."
) # DEBUG
scored_data = await self.score(to_score_list)
to_backlog = [] # Backlog not currently used but part of signature
print(f"DEBUG: Exiting collect_trajectories. Scored data: {bool(scored_data)}") # DEBUG
print(
f"DEBUG: Exiting collect_trajectories. Scored data: {bool(scored_data)}"
) # DEBUG
return scored_data, to_backlog
def save_checkpoint(self, step, data=None):
@ -817,7 +833,8 @@ def validate_word_constraint(text: str, N: int, quantifier: str) -> bool:
def verify_sentence_constraint(text: str, N: int, quantifier: str) -> bool:
# Basic sentence splitting, might need more robust NLP for complex cases
sentences = re.split(
r"(?<![a-zA-Z0-9_]\.[a-zA-Z0-9_]\.)(?<![A-Z][a-z]\.)(?<=\.|\?|!)\s", text.strip()
r"(?<![a-zA-Z0-9_]\.[a-zA-Z0-9_]\.)(?<![A-Z][a-z]\.)(?<=\.|\?|!)\s",
text.strip(),
)
# Filter out empty strings that might result from splitting
sentences = [s for s in sentences if s.strip()]