mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-19 12:57:58 +00:00
add stricter enforcement of think tags
This commit is contained in:
parent
6ae0703ad6
commit
daa6f0ff18
1 changed file with 55 additions and 23 deletions
|
|
@@ -61,13 +61,13 @@ class InstructionFollowingEnv(BaseEnv):
|
|||
# Configuration for the Instruction Following Environment
|
||||
env_config = IFConfig( # Changed BaseEnvConfig to IFConfig
|
||||
tokenizer_name="NousResearch/DeepHermes-3-Llama-3-8B-Preview",
|
||||
group_size=16,
|
||||
group_size=32,
|
||||
use_wandb=True,
|
||||
rollout_server_url="http://localhost:8000",
|
||||
total_steps=500,
|
||||
batch_size=1024,
|
||||
steps_per_eval=20,
|
||||
max_token_length=1024 * 16,
|
||||
max_token_length=1024 * 15,
|
||||
inference_weight=1.0,
|
||||
wandb_name="instruction_following_rlvr_ifeval", # Specific WandB project name
|
||||
eval_handling=EvalHandlingEnum.LIMIT_TRAIN,
|
||||
|
|
@@ -82,7 +82,7 @@ class InstructionFollowingEnv(BaseEnv):
|
|||
model_name="NousResearch/DeepHermes-3-Llama-3-8B-Preview",
|
||||
base_url="http://localhost:9004/v1",
|
||||
api_key="x",
|
||||
num_max_requests_at_once=16,
|
||||
num_max_requests_at_once=32,
|
||||
num_requests_for_eval=256,
|
||||
)
|
||||
]
|
||||
|
|
@@ -333,7 +333,50 @@ class InstructionFollowingEnv(BaseEnv):
|
|||
async def _get_score_from_verifier(
|
||||
self, model_response_text: str, func_name: str, args: Dict
|
||||
) -> float:
|
||||
"""Helper to call verifier function and get a numerical score."""
|
||||
"""Helper to call verifier function and get a numerical score.
|
||||
Also enforces strict <think>...</think> formatting.
|
||||
"""
|
||||
|
||||
# 1. Count <think> and </think> tags
|
||||
num_think_open = len(re.findall(r"<think>", model_response_text, re.IGNORECASE))
|
||||
num_think_close = len(re.findall(r"</think>", model_response_text, re.IGNORECASE))
|
||||
|
||||
# 2. If counts are not exactly one for each, malformed.
|
||||
if not (num_think_open == 1 and num_think_close == 1):
|
||||
# Optionally, add a debug print statement here if needed for logging
|
||||
# print(f"DEBUG: Malformed think tags. Open: {num_think_open}, Close: {num_think_close}. Response: '{model_response_text[:200]}...'")
|
||||
return 0.0
|
||||
|
||||
# 3. Find the first occurrence of <think> and </think>
|
||||
try:
|
||||
think_open_match = re.search(r"<think>", model_response_text, re.IGNORECASE)
|
||||
think_close_match = re.search(r"</think>", model_response_text, re.IGNORECASE)
|
||||
|
||||
# These should exist due to the count check, but access .start() and .end() safely
|
||||
idx_think_open = think_open_match.start()
|
||||
idx_think_close_start = think_close_match.start()
|
||||
idx_think_close_end = think_close_match.end()
|
||||
|
||||
except AttributeError:
|
||||
# This case should ideally be caught by the count check, but as a fallback:
|
||||
# print(f"DEBUG: Could not find start/end of think/close tags despite counts. Response: '{model_response_text[:200]}...'")
|
||||
return 0.0
|
||||
|
||||
# 4. If <think> appears after </think>, malformed.
|
||||
if idx_think_open >= idx_think_close_start:
|
||||
# print(f"DEBUG: <think> tag appears at or after </think> tag. Response: '{model_response_text[:200]}...'")
|
||||
return 0.0
|
||||
|
||||
# 5. Extract text_to_verify (content after the first </think>)
|
||||
text_to_verify = model_response_text[idx_think_close_end:].strip()
|
||||
|
||||
# 6. Check if text_to_verify itself contains any further <think> or </think> tags.
|
||||
if re.search(r"<think>", text_to_verify, re.IGNORECASE) or \
|
||||
re.search(r"</think>", text_to_verify, re.IGNORECASE):
|
||||
# print(f"DEBUG: Found <think> or </think> tags in the answer part. Answer part: '{text_to_verify[:200]}...'")
|
||||
return 0.0
|
||||
|
||||
# If all checks pass, proceed with verification using text_to_verify
|
||||
if func_name not in IF_FUNCTIONS_MAP:
|
||||
print(
|
||||
f"Warning: Verifier function '{func_name}' not found in IF_FUNCTIONS_MAP."
|
||||
|
|
@@ -345,58 +388,47 @@ class InstructionFollowingEnv(BaseEnv):
|
|||
raw_score = None
|
||||
try:
|
||||
if func_name == "validate_placeholders":
|
||||
# validate_placeholders expects 'text' and 'N', returns (bool, list)
|
||||
raw_score = verifier_func(model_response_text, N=args.get("N"))
|
||||
raw_score = verifier_func(text_to_verify, N=args.get("N"))
|
||||
elif func_name == "verify_bullet_points":
|
||||
# verify_bullet_points expects 'text' and 'N', returns bool (or was (bool,str) in one doc)
|
||||
# Assuming it now consistently returns bool as per our integrated version
|
||||
raw_score = verifier_func(model_response_text, N=args.get("N"))
|
||||
# Special handler for repeat-prompt, which only wants (text, original_prompt)
|
||||
raw_score = verifier_func(text_to_verify, N=args.get("N"))
|
||||
elif func_name == "validate_repeat_prompt":
|
||||
raw_score = verifier_func(
|
||||
model_response_text, args.get("original_prompt", "")
|
||||
text_to_verify, args.get("original_prompt", "")
|
||||
)
|
||||
# Generic handler: only pass kwargs the function actually declares
|
||||
else:
|
||||
from inspect import signature
|
||||
|
||||
sig = signature(verifier_func)
|
||||
# drop the 'text' parameter and keep only declared kwargs
|
||||
valid_params = [p for p in sig.parameters if p != "text"]
|
||||
filtered_args = {
|
||||
k: args[k]
|
||||
for k in valid_params
|
||||
if k in args and args[k] is not None
|
||||
}
|
||||
raw_score = verifier_func(model_response_text, **filtered_args)
|
||||
raw_score = verifier_func(text_to_verify, **filtered_args)
|
||||
|
||||
except LangDetectException: # Specifically catch for language detection issues
|
||||
except LangDetectException:
|
||||
print(
|
||||
f"Warning: langdetect failed for func_name '{func_name}'. Scoring as incorrect."
|
||||
)
|
||||
return 0.0
|
||||
except ImportError as e:
|
||||
# This might happen if a function tries a lazy import that fails (not langdetect now)
|
||||
print(
|
||||
f"Warning: ImportError during verifier function '{func_name}': {e}. Check dependencies."
|
||||
)
|
||||
return 0.0
|
||||
except TypeError as e:
|
||||
print(
|
||||
f"TypeError calling {func_name} with args {args}: {e}. Text: '{model_response_text[:100]}...'"
|
||||
f"TypeError calling {func_name} with args {args}: {e}. Text: '{text_to_verify[:100]}...'"
|
||||
)
|
||||
return 0.0
|
||||
except Exception as e: # Catch any other unexpected error from a verifier
|
||||
except Exception as e:
|
||||
print(
|
||||
f"Unexpected error in verifier function '{func_name}' with args {args}: {e}"
|
||||
)
|
||||
return 0.0
|
||||
|
||||
# Convert boolean or tuple[boolean, ...] to float score
|
||||
if isinstance(raw_score, tuple):
|
||||
score_value = float(
|
||||
raw_score[0]
|
||||
) # Assuming the first element is the boolean score
|
||||
score_value = float(raw_score[0])
|
||||
elif isinstance(raw_score, bool):
|
||||
score_value = float(raw_score)
|
||||
else:
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue