add stricter enforcement of think tags

teknium1 2025-05-16 13:18:20 -07:00
parent 6ae0703ad6
commit daa6f0ff18


@@ -61,13 +61,13 @@ class InstructionFollowingEnv(BaseEnv):
# Configuration for the Instruction Following Environment
env_config = IFConfig( # Changed BaseEnvConfig to IFConfig
tokenizer_name="NousResearch/DeepHermes-3-Llama-3-8B-Preview",
group_size=16,
group_size=32,
use_wandb=True,
rollout_server_url="http://localhost:8000",
total_steps=500,
batch_size=1024,
steps_per_eval=20,
max_token_length=1024 * 16,
max_token_length=1024 * 15,
inference_weight=1.0,
wandb_name="instruction_following_rlvr_ifeval", # Specific WandB project name
eval_handling=EvalHandlingEnum.LIMIT_TRAIN,
@@ -82,7 +82,7 @@ class InstructionFollowingEnv(BaseEnv):
model_name="NousResearch/DeepHermes-3-Llama-3-8B-Preview",
base_url="http://localhost:9004/v1",
api_key="x",
num_max_requests_at_once=16,
num_max_requests_at_once=32,
num_requests_for_eval=256,
)
]
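Purely as illustrative arithmetic on the values changed above (the commit itself does not state a motivation), the worst-case token footprint per rollout group roughly doubles even though the per-rollout limit shrinks slightly:

```python
# Before: 16 rollouts per group, up to 16 * 1024 tokens each.
before = 16 * 16 * 1024   # 262,144 tokens per group (upper bound)
# After: 32 rollouts per group, up to 15 * 1024 tokens each.
after = 32 * 15 * 1024    # 491,520 tokens per group (upper bound)
print(before, after)
```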
@@ -333,7 +333,50 @@ class InstructionFollowingEnv(BaseEnv):
async def _get_score_from_verifier(
self, model_response_text: str, func_name: str, args: Dict
) -> float:
"""Helper to call verifier function and get a numerical score."""
"""Helper to call verifier function and get a numerical score.
Also enforces strict <think>...</think> formatting.
"""
# 1. Count <think> and </think> tags
num_think_open = len(re.findall(r"<think>", model_response_text, re.IGNORECASE))
num_think_close = len(re.findall(r"</think>", model_response_text, re.IGNORECASE))
# 2. If counts are not exactly one for each, malformed.
if not (num_think_open == 1 and num_think_close == 1):
# Optionally, add a debug print statement here if needed for logging
# print(f"DEBUG: Malformed think tags. Open: {num_think_open}, Close: {num_think_close}. Response: '{model_response_text[:200]}...'")
return 0.0
# 3. Find the first occurrence of <think> and </think>
try:
think_open_match = re.search(r"<think>", model_response_text, re.IGNORECASE)
think_close_match = re.search(r"</think>", model_response_text, re.IGNORECASE)
# These should exist due to the count check, but access .start() and .end() safely
idx_think_open = think_open_match.start()
idx_think_close_start = think_close_match.start()
idx_think_close_end = think_close_match.end()
except AttributeError:
# This case should ideally be caught by the count check, but as a fallback:
# print(f"DEBUG: Could not find start/end of think/close tags despite counts. Response: '{model_response_text[:200]}...'")
return 0.0
# 4. If <think> appears after </think>, malformed.
if idx_think_open >= idx_think_close_start:
# print(f"DEBUG: <think> tag appears at or after </think> tag. Response: '{model_response_text[:200]}...'")
return 0.0
# 5. Extract text_to_verify (content after the first </think>)
text_to_verify = model_response_text[idx_think_close_end:].strip()
# 6. Check if text_to_verify itself contains any further <think> or </think> tags.
if re.search(r"<think>", text_to_verify, re.IGNORECASE) or \
re.search(r"</think>", text_to_verify, re.IGNORECASE):
# print(f"DEBUG: Found <think> or </think> tags in the answer part. Answer part: '{text_to_verify[:200]}...'")
return 0.0
# If all checks pass, proceed with verification using text_to_verify
if func_name not in IF_FUNCTIONS_MAP:
print(
f"Warning: Verifier function '{func_name}' not found in IF_FUNCTIONS_MAP."
@@ -345,58 +388,47 @@ class InstructionFollowingEnv(BaseEnv):
raw_score = None
try:
if func_name == "validate_placeholders":
# validate_placeholders expects 'text' and 'N', returns (bool, list)
raw_score = verifier_func(model_response_text, N=args.get("N"))
raw_score = verifier_func(text_to_verify, N=args.get("N"))
elif func_name == "verify_bullet_points":
# verify_bullet_points expects 'text' and 'N', returns bool (or was (bool,str) in one doc)
# Assuming it now consistently returns bool as per our integrated version
raw_score = verifier_func(model_response_text, N=args.get("N"))
# Special handler for repeat-prompt, which only wants (text, original_prompt)
raw_score = verifier_func(text_to_verify, N=args.get("N"))
elif func_name == "validate_repeat_prompt":
raw_score = verifier_func(
model_response_text, args.get("original_prompt", "")
text_to_verify, args.get("original_prompt", "")
)
# Generic handler: only pass kwargs the function actually declares
else:
from inspect import signature
sig = signature(verifier_func)
# drop the 'text' parameter and keep only declared kwargs
valid_params = [p for p in sig.parameters if p != "text"]
filtered_args = {
k: args[k]
for k in valid_params
if k in args and args[k] is not None
}
raw_score = verifier_func(model_response_text, **filtered_args)
raw_score = verifier_func(text_to_verify, **filtered_args)
except LangDetectException: # Specifically catch for language detection issues
except LangDetectException:
print(
f"Warning: langdetect failed for func_name '{func_name}'. Scoring as incorrect."
)
return 0.0
except ImportError as e:
# This might happen if a function tries a lazy import that fails (not langdetect now)
print(
f"Warning: ImportError during verifier function '{func_name}': {e}. Check dependencies."
)
return 0.0
except TypeError as e:
print(
f"TypeError calling {func_name} with args {args}: {e}. Text: '{model_response_text[:100]}...'"
f"TypeError calling {func_name} with args {args}: {e}. Text: '{text_to_verify[:100]}...'"
)
return 0.0
except Exception as e: # Catch any other unexpected error from a verifier
except Exception as e:
print(
f"Unexpected error in verifier function '{func_name}' with args {args}: {e}"
)
return 0.0
# Convert boolean or tuple[boolean, ...] to float score
if isinstance(raw_score, tuple):
score_value = float(
raw_score[0]
) # Assuming the first element is the boolean score
score_value = float(raw_score[0])
elif isinstance(raw_score, bool):
score_value = float(raw_score)
else:
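For the generic branch, the kwargs filtering uses `inspect.signature` so each verifier only receives arguments it actually declares, with `None` values dropped. A self-contained illustration of that pattern together with the bool/tuple score conversion shown above (the `validate_word_count` verifier and its arguments are made up for the example):

```python
from inspect import signature
from typing import Optional


def validate_word_count(text: str, min_words: int = 0, max_words: Optional[int] = None) -> bool:
    """Toy verifier: checks the word count of the answer text."""
    n = len(text.split())
    if n < min_words:
        return False
    if max_words is not None and n > max_words:
        return False
    return True


# Arguments as they might arrive from a dataset row: some irrelevant, some None.
args = {"min_words": 3, "max_words": None, "language": "en"}

sig = signature(validate_word_count)
valid_params = [p for p in sig.parameters if p != "text"]
filtered_args = {k: args[k] for k in valid_params if k in args and args[k] is not None}
# Only {"min_words": 3} survives: "language" is undeclared, "max_words" is None.

raw_score = validate_word_count("this has four words", **filtered_args)

# Same conversion as the commit: tuples take their first element, bools cast to float.
if isinstance(raw_score, tuple):
    score = float(raw_score[0])
elif isinstance(raw_score, bool):
    score = float(raw_score)
else:
    score = 0.0

print(score)  # 1.0
```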