Mirror of https://github.com/NousResearch/atropos.git (synced 2026-04-19 12:57:58 +00:00)
some cleanup for final merge
parent daa6f0ff18
commit 287bbcd356

2 changed files with 58 additions and 39 deletions
@@ -4,9 +4,9 @@ import re
 from typing import Dict, List, Optional, Tuple
 
 import wandb
-from datasets import Dataset, load_dataset  # Added Dataset for dummy data
+from datasets import Dataset, load_dataset
 from langdetect import LangDetectException, detect
-from pydantic import Field  # Added import for Field
+from pydantic import Field
 from tqdm.asyncio import tqdm_asyncio
 
 from atroposlib.envs.base import (
@@ -39,11 +39,11 @@ class IFConfig(BaseEnvConfig):
 
 
 class InstructionFollowingEnv(BaseEnv):
-    env_config_cls = IFConfig  # Added env_config_cls for IFConfig
+    env_config_cls = IFConfig
 
     def __init__(
         self,
-        config: IFConfig,  # Changed BaseEnvConfig to IFConfig
+        config: IFConfig,
         server_configs: List[APIServerConfig],
         slurm=True,
         testing=False,
@@ -52,14 +52,13 @@ class InstructionFollowingEnv(BaseEnv):
         self.percent_correct_buffer = list()
         self.eval_metrics = list()
         self.rollouts_for_wandb = []
         # self.completion_lengths = []  # Kept from SingleToolCallingEnv, assess utility
 
     @classmethod
     def config_init(
         self,
-    ) -> Tuple[IFConfig, List[APIServerConfig]]:  # Changed BaseEnvConfig to IFConfig
-        # Configuration for the Instruction Following Environment
-        env_config = IFConfig(  # Changed BaseEnvConfig to IFConfig
+    ) -> Tuple[IFConfig, List[APIServerConfig]]:
+        env_config = IFConfig(
             tokenizer_name="NousResearch/DeepHermes-3-Llama-3-8B-Preview",
             group_size=32,
             use_wandb=True,
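Note on the pattern in this hunk: `env_config_cls` tells the base environment which pydantic config class to instantiate, and `config_init` builds a default instance of it. A minimal runnable sketch, assuming `BaseEnvConfig` is a pydantic model; the stand-in base class and the `dataset_name` field below are ours for illustration, not atroposlib's actual definitions:

from pydantic import BaseModel, Field


class BaseEnvConfig(BaseModel):  # stand-in for atroposlib.envs.base.BaseEnvConfig
    tokenizer_name: str = "gpt2"
    group_size: int = 8
    use_wandb: bool = False


class IFConfig(BaseEnvConfig):
    # Field mirrors the pydantic import in the diff; this default is hypothetical.
    dataset_name: str = Field(
        default="allenai/RLVR-IFeval",
        description="HF dataset of verifiable instruction-following prompts",
    )


env_config = IFConfig(
    tokenizer_name="NousResearch/DeepHermes-3-Llama-3-8B-Preview",
    group_size=32,
    use_wandb=True,
)
print(env_config.dataset_name)  # -> allenai/RLVR-IFeval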
@@ -130,7 +129,7 @@ class InstructionFollowingEnv(BaseEnv):
         - 'prompt': The user's instruction string.
         - 'func_name': The string name of the verifier function.
         - 'args': A dictionary of arguments for that verifier function.
-        """
+        """  # noqa: E501
         dataset_name = getattr(self.config, "dataset_name", "allenai/RLVR-IFeval")
         dataset_config_name = getattr(
             self.config, "dataset_config_name", None
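For concreteness, an item matching the docstring schema above might look like the dict below; the prompt, verifier name, and args are invented for illustration, since the real values come from the RLVR-IFeval rows:

item = {
    "prompt": "Write a short poem about rain. Include the keyword 'umbrella'.",
    "func_name": "verify_keywords",  # hypothetical verifier name
    "args": {"keyword_list": ["umbrella"]},
}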
@@ -231,7 +230,7 @@ class InstructionFollowingEnv(BaseEnv):
                     "original_constraints_for_logging": str(
                         item.get("constraint", "")
                     ),  # For logging, from RLVR-IFeval structure
-                    "expected_response_for_logging": "",  # RLVR-IFeval doesn't seem to have a sample good response directly  # noqa: E501
+                    "expected_response_for_logging": "",
                 }
             )
 
@@ -251,7 +250,8 @@ class InstructionFollowingEnv(BaseEnv):
 
         except Exception as e:
             # This block is a fallback if the primary dataset loading/processing fails catastrophically.
-            # For RLVR-IFeval, a failure here suggests issues with Hugging Face access, dataset integrity, or fundamental code errors.  # noqa: E501
+            # For RLVR-IFeval, a failure here suggests issues with Hugging Face access,
+            # dataset integrity, or fundamental code errors.
             print(
                 f"CRITICAL: Failed to load or process primary dataset '{dataset_name}': {e}. "
                 f"Using a DUMMY dataset as fallback."
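The shape of that fallback, as a self-contained sketch; the dummy row is invented, and the real code builds its own placeholder items after the CRITICAL print:

from datasets import Dataset, load_dataset

try:
    full_dataset = load_dataset("allenai/RLVR-IFeval", split="train")
except Exception as e:
    print(f"CRITICAL: Failed to load or process primary dataset: {e}. Using a DUMMY dataset as fallback.")
    # Dataset.from_list builds an in-memory dataset, which is why the diff
    # keeps `Dataset` in the imports alongside `load_dataset`.
    full_dataset = Dataset.from_list(
        [
            {
                "prompt": "Reply with the word 'hello'.",
                "func_name": "verify_keywords",  # hypothetical verifier
                "args": {"keyword_list": ["hello"]},
            }
        ]
    )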
@@ -316,12 +316,14 @@ class InstructionFollowingEnv(BaseEnv):
         # Final check for empty train/test after split, should not happen with logic above if num_items > 0
         if len(self.train) == 0 and len(self.test) > 0:
             print(
-                "Warning: Train set empty after split, test set has data. This is unusual. Swapping."
+                "Warning: Train set empty after split, test set has data. "
+                "This is unusual. Swapping."
             )
             self.train = self.test  # Fallback, though indicates issue
         elif len(self.test) == 0 and len(self.train) > 0:
             print(
-                "Warning: Test set empty after split, train set has data. Using full train set for test as well."
+                "Warning: Test set empty after split, train set has data. "
+                "Using full train set for test as well."
             )
             self.test = self.train
 
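The same guards in isolation, continuing the load_dataset sketch above and assuming the split is done with datasets' train_test_split (the split sizes are illustrative):

split = full_dataset.train_test_split(test_size=0.05, seed=42)
train, test = list(split["train"]), list(split["test"])

if len(train) == 0 and len(test) > 0:
    train = test  # unusual; swap so training can still proceed
elif len(test) == 0 and len(train) > 0:
    test = train  # fall back to evaluating on the training items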
@@ -339,41 +341,40 @@ class InstructionFollowingEnv(BaseEnv):
 
         # 1. Count <think> and </think> tags
         num_think_open = len(re.findall(r"<think>", model_response_text, re.IGNORECASE))
-        num_think_close = len(re.findall(r"</think>", model_response_text, re.IGNORECASE))
+        num_think_close = len(
+            re.findall(r"</think>", model_response_text, re.IGNORECASE)
+        )
 
         # 2. If counts are not exactly one for each, malformed.
         if not (num_think_open == 1 and num_think_close == 1):
             # Optionally, add a debug print statement here if needed for logging
             # print(f"DEBUG: Malformed think tags. Open: {num_think_open}, Close: {num_think_close}. Response: '{model_response_text[:200]}...'")
             return 0.0
 
         # 3. Find the first occurrence of <think> and </think>
         try:
             think_open_match = re.search(r"<think>", model_response_text, re.IGNORECASE)
-            think_close_match = re.search(r"</think>", model_response_text, re.IGNORECASE)
-
+            think_close_match = re.search(
+                r"</think>", model_response_text, re.IGNORECASE
+            )
 
             # These should exist due to the count check, but access .start() and .end() safely
             idx_think_open = think_open_match.start()
             idx_think_close_start = think_close_match.start()
             idx_think_close_end = think_close_match.end()
 
-        except AttributeError:
-            # This case should ideally be caught by the count check, but as a fallback:
-            # print(f"DEBUG: Could not find start/end of think/close tags despite counts. Response: '{model_response_text[:200]}...'")
+        except AttributeError:
             return 0.0
 
         # 4. If <think> appears after </think>, malformed.
         if idx_think_open >= idx_think_close_start:
-            # print(f"DEBUG: <think> tag appears at or after </think> tag. Response: '{model_response_text[:200]}...'")
             return 0.0
 
-
         # 5. Extract text_to_verify (content after the first </think>)
         text_to_verify = model_response_text[idx_think_close_end:].strip()
 
         # 6. Check if text_to_verify itself contains any further <think> or </think> tags.
-        if re.search(r"<think>", text_to_verify, re.IGNORECASE) or \
-           re.search(r"</think>", text_to_verify, re.IGNORECASE):
-            # print(f"DEBUG: Found <think> or </think> tags in the answer part. Answer part: '{text_to_verify[:200]}...'")
+        if re.search(r"<think>", text_to_verify, re.IGNORECASE) or re.search(
+            r"</think>", text_to_verify, re.IGNORECASE
+        ):
             return 0.0
 
         # If all checks pass, proceed with verification using text_to_verify
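A standalone sketch of the <think>-tag validation this hunk reformats: exactly one open and one close tag, in order, with the verifiable answer being whatever follows </think>. The function name is ours, not the environment's, and it returns the answer text rather than a score:

import re
from typing import Optional


def extract_answer_after_think(model_response_text: str) -> Optional[str]:
    opens = len(re.findall(r"<think>", model_response_text, re.IGNORECASE))
    closes = len(re.findall(r"</think>", model_response_text, re.IGNORECASE))
    if opens != 1 or closes != 1:
        return None  # malformed: tag counts must be exactly one each
    open_match = re.search(r"<think>", model_response_text, re.IGNORECASE)
    close_match = re.search(r"</think>", model_response_text, re.IGNORECASE)
    if open_match.start() >= close_match.start():
        return None  # malformed: <think> must precede </think>
    answer = model_response_text[close_match.end():].strip()
    if re.search(r"</?think>", answer, re.IGNORECASE):
        return None  # malformed: no stray tags in the answer part
    return answer


assert extract_answer_after_think("<think>plan</think> Final answer.") == "Final answer."
assert extract_answer_after_think("no tags at all") is None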
@@ -397,6 +398,7 @@ class InstructionFollowingEnv(BaseEnv):
             )
         else:
             from inspect import signature
+
             sig = signature(verifier_func)
             valid_params = [p for p in sig.parameters if p != "text"]
             filtered_args = {
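What that inspect.signature filtering accomplishes, in isolation: any configured args the verifier does not accept are dropped, so a stale or extra dataset arg cannot crash the call. The toy verifier below is a stand-in:

from inspect import signature


def verify_keywords(text: str, keyword_list: list) -> bool:
    return all(k.lower() in text.lower() for k in keyword_list)


args_for_verifier = {"keyword_list": ["umbrella"], "unexpected_arg": 7}
sig = signature(verify_keywords)
valid_params = [p for p in sig.parameters if p != "text"]
filtered_args = {k: v for k, v in args_for_verifier.items() if k in valid_params}
print(verify_keywords("Bring an umbrella.", **filtered_args))  # True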
@@ -446,7 +448,9 @@ class InstructionFollowingEnv(BaseEnv):
         func_name = test_item["func_name"]
         args_for_verifier = test_item["args"]
 
-        print(f"DEBUG: Entering rollout_and_score_eval. Prompt: {instruction_prompt_text[:200]}...")  # DEBUG
+        print(
+            f"DEBUG: Entering rollout_and_score_eval. Prompt: {instruction_prompt_text[:200]}..."
+        )  # DEBUG
 
         messages = [{"role": "system", "content": system_prompt}]
         messages.append({"role": "user", "content": instruction_prompt_text})
@@ -455,7 +459,9 @@ class InstructionFollowingEnv(BaseEnv):
             messages, add_generation_prompt=True, tokenize=False
         )
 
-        print(f"DEBUG: Calling self.server.completion in rollout_and_score_eval. Prompt: {prompt_str[:200]}...")  # DEBUG
+        print(
+            f"DEBUG: Calling self.server.completion in rollout_and_score_eval. Prompt: {prompt_str[:200]}..."
+        )  # DEBUG
         completion = await self.server.completion(
             prompt=prompt_str,
             n=1,
@@ -463,7 +469,7 @@ class InstructionFollowingEnv(BaseEnv):
             temperature=0.2,  # Temperature for eval, can be 0 for deterministic
             split="eval",
         )
-        print(f"DEBUG: Received completion in rollout_and_score_eval.")  # DEBUG
+        print("DEBUG: Received completion in rollout_and_score_eval.")  # DEBUG
 
         model_response_text = completion.choices[0].text
         score_value = await self._get_score_from_verifier(
@@ -481,7 +487,7 @@ class InstructionFollowingEnv(BaseEnv):
             self.eval_metrics.append(("eval/percent_correct", 0.0))
             return
 
-        print(f"DEBUG: Starting evaluation. Test set size: {len(self.test)}") # DEBUG
+        print(f"DEBUG: Starting evaluation. Test set size: {len(self.test)}")  # DEBUG
         eval_tasks = []
         for test_item_dict in self.test:  # self.test contains dicts after setup
             eval_tasks.append(self.rollout_and_score_eval(test_item_dict))
@@ -501,7 +507,7 @@ class InstructionFollowingEnv(BaseEnv):
     ) -> Tuple[Optional[ScoredDataGroup], List]:
         # item = (prompt_messages_tuple, answer_info_dict)
         # answer_info_dict = {"func_name": ..., "args": ...}
-        print(f"DEBUG: Entering collect_trajectories. Item: {item}")  # DEBUG
+        print(f"DEBUG: Entering collect_trajectories. Item: {str(item)}")  # DEBUG
         prompt_messages_list = [dict(msg_fset) for msg_fset in item[0]]
         answer_info = item[1]
 
@@ -509,7 +515,9 @@ class InstructionFollowingEnv(BaseEnv):
             prompt_messages_list, add_generation_prompt=True, tokenize=False
         )
 
-        print(f"DEBUG: Calling self.server.completion in collect_trajectories. Prompt: {prompt_str[:200]}...")  # DEBUG
+        print(
+            f"DEBUG: Calling self.server.completion in collect_trajectories. Prompt: {prompt_str[:200]}..."
+        )  # DEBUG
         try:
             completions = await self.server.completion(
                 prompt=prompt_str,
@@ -517,11 +525,15 @@ class InstructionFollowingEnv(BaseEnv):
                 max_tokens=self.config.max_token_length,
                 temperature=0.8,  # Temperature for diverse responses during training rollouts
             )
-            print(f"DEBUG: Received {len(completions.choices)} completions in collect_trajectories.")  # DEBUG
+            print(
+                f"DEBUG: Received {len(completions.choices)} completions in collect_trajectories."
+            )  # DEBUG
         except Exception as e:
-            print(f"ERROR: Exception during self.server.completion in collect_trajectories: {e}")  # DEBUG
+            print(
+                f"ERROR: Exception during self.server.completion in collect_trajectories: {e}"
+            )  # DEBUG
             # Depending on the desired behavior, you might want to return None or raise the exception
-            return None, []
+            return None, []
 
         to_score_list = []
         for choice in completions.choices:
@@ -534,11 +546,15 @@ class InstructionFollowingEnv(BaseEnv):
         if not to_score_list:
             return None, []
 
-        print(f"DEBUG: Scoring {len(to_score_list)} trajectories in collect_trajectories.")  # DEBUG
+        print(
+            f"DEBUG: Scoring {len(to_score_list)} trajectories in collect_trajectories."
+        )  # DEBUG
         scored_data = await self.score(to_score_list)
         to_backlog = []  # Backlog not currently used but part of signature
 
-        print(f"DEBUG: Exiting collect_trajectories. Scored data: {bool(scored_data)}")  # DEBUG
+        print(
+            f"DEBUG: Exiting collect_trajectories. Scored data: {bool(scored_data)}"
+        )  # DEBUG
         return scored_data, to_backlog
 
     def save_checkpoint(self, step, data=None):
@@ -817,7 +833,8 @@ def validate_word_constraint(text: str, N: int, quantifier: str) -> bool:
 def verify_sentence_constraint(text: str, N: int, quantifier: str) -> bool:
     # Basic sentence splitting, might need more robust NLP for complex cases
     sentences = re.split(
-        r"(?<![a-zA-Z0-9_]\.[a-zA-Z0-9_]\.)(?<![A-Z][a-z]\.)(?<=\.|\?|!)\s", text.strip()
+        r"(?<![a-zA-Z0-9_]\.[a-zA-Z0-9_]\.)(?<![A-Z][a-z]\.)(?<=\.|\?|!)\s",
+        text.strip(),
    )
     # Filter out empty strings that might result from splitting
     sentences = [s for s in sentences if s.strip()]
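A quick check of what that sentence-splitting regex does: the lookbehinds block splits after initials like "U.S." and title-style abbreviations like "Dr.", while still splitting on sentence-ending punctuation. The sample text is ours:

import re

SENTENCE_SPLIT = r"(?<![a-zA-Z0-9_]\.[a-zA-Z0-9_]\.)(?<![A-Z][a-z]\.)(?<=\.|\?|!)\s"
text = "Dr. Smith arrived. Did he bring the U.S. report? Yes! It was late."
sentences = [s for s in re.split(SENTENCE_SPLIT, text.strip()) if s.strip()]
print(sentences)
# -> ['Dr. Smith arrived.', 'Did he bring the U.S. report?', 'Yes!', 'It was late.']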