mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-19 12:57:58 +00:00
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
This commit is contained in:
parent
ef9c0c3699
commit
afab28dfa9
37 changed files with 4868 additions and 4052 deletions
|
|
@ -27,6 +27,15 @@ from typing import Any, Dict, List, Optional, Tuple
|
|||
|
||||
import wandb
|
||||
from datasets import load_dataset
|
||||
from eval_helpers import (
|
||||
build_mcqa_fallback_patterns,
|
||||
create_system_content,
|
||||
extract_letter_from_answer_tag,
|
||||
extract_thinking_content,
|
||||
get_default_thinking_prompt,
|
||||
save_eval_results,
|
||||
validate_thinking_format,
|
||||
)
|
||||
from pydantic import Field
|
||||
from tqdm.asyncio import tqdm_asyncio
|
||||
|
||||
|
|
@ -36,15 +45,6 @@ from atroposlib.envs.base import (
|
|||
BaseEnvConfig,
|
||||
EvalHandlingEnum,
|
||||
)
|
||||
from eval_helpers import (
|
||||
extract_letter_from_answer_tag,
|
||||
validate_thinking_format,
|
||||
extract_thinking_content,
|
||||
get_default_thinking_prompt,
|
||||
create_system_content,
|
||||
save_eval_results,
|
||||
build_mcqa_fallback_patterns,
|
||||
)
|
||||
|
||||
|
||||
class OBQAEvalConfig(BaseEnvConfig):
|
||||
|
|
@ -123,10 +123,10 @@ class OBQAEvalConfig(BaseEnvConfig):
|
|||
class OBQAEvalEnv(BaseEnv):
|
||||
"""
|
||||
OpenBookQA Evaluation Environment for Atropos.
|
||||
|
||||
|
||||
Evaluates models on common sense reasoning with multiple choice questions.
|
||||
"""
|
||||
|
||||
|
||||
name = "obqa_eval"
|
||||
env_config_cls = OBQAEvalConfig
|
||||
|
||||
|
|
@ -140,78 +140,72 @@ class OBQAEvalEnv(BaseEnv):
|
|||
super().__init__(config, server_configs, slurm, testing)
|
||||
self.config: OBQAEvalConfig = config
|
||||
self.eval_metrics = []
|
||||
|
||||
|
||||
# Regex patterns for thinking mode
|
||||
self._think_pattern = re.compile(r"<think>")
|
||||
self._think_close_pattern = re.compile(r"</think>")
|
||||
self._think_content_pattern = re.compile(r"</think>\s*(.*)", re.DOTALL)
|
||||
self._thinking_extract_pattern = re.compile(r"<think>(.*?)</think>", re.DOTALL)
|
||||
|
||||
|
||||
# Pre-compile regex for <answer></answer> tag extraction (primary method)
|
||||
self._answer_tag_pattern = re.compile(r"<answer>(.*?)</answer>", re.DOTALL | re.IGNORECASE)
|
||||
|
||||
self._answer_tag_pattern = re.compile(
|
||||
r"<answer>(.*?)</answer>", re.DOTALL | re.IGNORECASE
|
||||
)
|
||||
|
||||
# Build fallback answer extraction patterns
|
||||
self._build_extraction_patterns()
|
||||
|
||||
def _build_extraction_patterns(self):
|
||||
"""
|
||||
Build regex patterns for extracting answer letters from model responses.
|
||||
|
||||
|
||||
Patterns are ordered by priority (lower number = higher priority).
|
||||
Takes the LAST match for answer patterns since models often repeat the final answer.
|
||||
"""
|
||||
# Valid answer letters for OBQA (A-D)
|
||||
letters = "ABCD"
|
||||
letter_pattern = rf"([{letters}]|\([{letters}]\))"
|
||||
|
||||
|
||||
# Priority 0: "final answer is: X" with "I hope" (very specific, highest confidence)
|
||||
self._pattern_final_answer_hope = re.compile(
|
||||
rf"(?i:final\s+answer\s+is)\s*:?\s*{letter_pattern}\.?\s*I\s*hope",
|
||||
re.IGNORECASE
|
||||
re.IGNORECASE,
|
||||
)
|
||||
|
||||
|
||||
# Priority 50: "final answer ... is X" (allows text between)
|
||||
self._pattern_final_answer_is = re.compile(
|
||||
rf"(?i:final\s+answer).{{0,100}}?\s+is\s*:?\s*{letter_pattern}",
|
||||
re.IGNORECASE | re.DOTALL
|
||||
re.IGNORECASE | re.DOTALL,
|
||||
)
|
||||
|
||||
|
||||
# Priority 75: "the answer is X"
|
||||
self._pattern_the_answer_is = re.compile(
|
||||
rf"(?i:the\s+answer\s+is)\s*:?\s*{letter_pattern}",
|
||||
re.IGNORECASE
|
||||
rf"(?i:the\s+answer\s+is)\s*:?\s*{letter_pattern}", re.IGNORECASE
|
||||
)
|
||||
|
||||
|
||||
# Priority 100: "answer: X" or "Answer: X" (with colon)
|
||||
self._pattern_answer_colon = re.compile(
|
||||
rf"(?i:answer)\s*:\s*.{{0,50}}?{letter_pattern}",
|
||||
re.IGNORECASE | re.DOTALL
|
||||
rf"(?i:answer)\s*:\s*.{{0,50}}?{letter_pattern}", re.IGNORECASE | re.DOTALL
|
||||
)
|
||||
|
||||
|
||||
# Priority 150: "answer X" or "Answer X" (without colon)
|
||||
self._pattern_answer_space = re.compile(
|
||||
rf"(?i:answer)\s+{letter_pattern}",
|
||||
re.IGNORECASE
|
||||
rf"(?i:answer)\s+{letter_pattern}", re.IGNORECASE
|
||||
)
|
||||
|
||||
|
||||
# Priority 200: Response starts with answer letter (with optional punctuation)
|
||||
self._pattern_start = re.compile(
|
||||
rf"^\s*\**{letter_pattern}\**[\s\.\)\:]",
|
||||
re.IGNORECASE
|
||||
rf"^\s*\**{letter_pattern}\**[\s\.\)\:]", re.IGNORECASE
|
||||
)
|
||||
|
||||
|
||||
# Priority 210: Letter at start of any line (for multi-line responses)
|
||||
self._pattern_line_start = re.compile(
|
||||
rf"\n\s*\**{letter_pattern}\**[\s\.\)\:]",
|
||||
re.IGNORECASE
|
||||
rf"\n\s*\**{letter_pattern}\**[\s\.\)\:]", re.IGNORECASE
|
||||
)
|
||||
|
||||
|
||||
# Priority 250: Standalone letter with word boundaries
|
||||
self._pattern_standalone = re.compile(
|
||||
rf"\b{letter_pattern}\b",
|
||||
re.IGNORECASE
|
||||
)
|
||||
|
||||
self._pattern_standalone = re.compile(rf"\b{letter_pattern}\b", re.IGNORECASE)
|
||||
|
||||
# Store patterns in priority order
|
||||
self._extraction_patterns = [
|
||||
(0, self._pattern_final_answer_hope, "final_answer_hope"),
|
||||
|
|
@ -233,7 +227,7 @@ class OBQAEvalEnv(BaseEnv):
|
|||
return create_system_content(
|
||||
self.config.thinking_mode,
|
||||
self.config.custom_thinking_prompt,
|
||||
self.config.custom_system_prompt
|
||||
self.config.custom_system_prompt,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
|
|
@ -260,7 +254,7 @@ class OBQAEvalEnv(BaseEnv):
|
|||
eval_max_tokens=0,
|
||||
thinking_mode=True,
|
||||
)
|
||||
|
||||
|
||||
server_configs = [
|
||||
APIServerConfig(
|
||||
model_name="Hermes-3-Llama-3.1-8B",
|
||||
|
|
@ -270,12 +264,12 @@ class OBQAEvalEnv(BaseEnv):
|
|||
num_requests_for_eval=1024,
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
return env_config, server_configs
|
||||
|
||||
def _format_obqa_prompt(self, item: Dict) -> str:
|
||||
"""Format an OpenBookQA item into a prompt with <answer> tag instruction.
|
||||
|
||||
|
||||
Based on lighteval's openbookqa_prompt but with explicit <answer> tag instruction.
|
||||
"""
|
||||
question = item.get("question_stem", "")
|
||||
|
|
@ -283,16 +277,16 @@ class OBQAEvalEnv(BaseEnv):
|
|||
choice_texts = choices.get("text", [])
|
||||
num_choices = len(choice_texts)
|
||||
valid_letters = ", ".join(ascii_uppercase[:num_choices])
|
||||
|
||||
|
||||
query = "Answer the following multiple choice question about common sense. Think step by step before answering.\n\n"
|
||||
query += f"Provide your final answer within <answer></answer> tags, containing only the letter ({valid_letters}).\n\n"
|
||||
query += "Example format:\n<answer>A</answer>\n\n"
|
||||
query += f"Question: {question}\n"
|
||||
|
||||
|
||||
for i, choice_text in enumerate(choice_texts):
|
||||
letter = ascii_uppercase[i]
|
||||
query += f"{letter}. {choice_text}\n"
|
||||
|
||||
|
||||
return query
|
||||
|
||||
async def setup(self) -> None:
|
||||
|
|
@ -305,7 +299,7 @@ class OBQAEvalEnv(BaseEnv):
|
|||
print(f" Thinking mode: {self.config.thinking_mode}")
|
||||
if self.config.thinking_mode:
|
||||
print(f" Thinking prompt: {self._get_thinking_prompt()[:100]}...")
|
||||
|
||||
|
||||
try:
|
||||
dataset = load_dataset(
|
||||
self.config.dataset_name,
|
||||
|
|
@ -314,11 +308,11 @@ class OBQAEvalEnv(BaseEnv):
|
|||
)
|
||||
self.eval_data = list(dataset)
|
||||
print(f" Loaded {len(self.eval_data)} evaluation items")
|
||||
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error loading OpenBookQA dataset: {e}")
|
||||
raise
|
||||
|
||||
|
||||
self.all_eval_items = self.eval_data
|
||||
self.iter = 0
|
||||
|
||||
|
|
@ -326,13 +320,13 @@ class OBQAEvalEnv(BaseEnv):
|
|||
"""Validate thinking format and extract content after </think> tags."""
|
||||
if not self.config.thinking_mode:
|
||||
return True, response
|
||||
|
||||
|
||||
think_open_count = len(self._think_pattern.findall(response))
|
||||
think_close_count = len(self._think_close_pattern.findall(response))
|
||||
|
||||
|
||||
if think_open_count != 1 or think_close_count != 1:
|
||||
return False, response
|
||||
|
||||
|
||||
match = self._think_content_pattern.search(response)
|
||||
if match:
|
||||
return True, match.group(1).strip()
|
||||
|
|
@ -347,30 +341,27 @@ class OBQAEvalEnv(BaseEnv):
|
|||
return None
|
||||
|
||||
def _extract_answer_letter(
|
||||
self,
|
||||
response: str,
|
||||
num_choices: int,
|
||||
choices: Optional[List[str]] = None
|
||||
self, response: str, num_choices: int, choices: Optional[List[str]] = None
|
||||
) -> Tuple[Optional[str], str]:
|
||||
"""
|
||||
Extract the answer letter from the model's response.
|
||||
|
||||
|
||||
Primary method: Look for <answer></answer> tags, or match against choice texts.
|
||||
Fallback: Use priority-ordered regex patterns.
|
||||
|
||||
|
||||
Args:
|
||||
response: The model's response string (content after </think> in thinking mode)
|
||||
num_choices: Number of valid choices (determines valid letters)
|
||||
choices: Optional list of choice texts for exact matching
|
||||
|
||||
|
||||
Returns:
|
||||
Tuple of (extracted_letter or None, extraction_method used)
|
||||
"""
|
||||
if not response:
|
||||
return None, "empty_response"
|
||||
|
||||
|
||||
valid_letters = set(ascii_uppercase[:num_choices])
|
||||
|
||||
|
||||
# PRIMARY: Try <answer></answer> tags first
|
||||
# Also matches against choice texts if provided
|
||||
letter, method = extract_letter_from_answer_tag(
|
||||
|
|
@ -378,32 +369,46 @@ class OBQAEvalEnv(BaseEnv):
|
|||
)
|
||||
if letter:
|
||||
return letter, method
|
||||
|
||||
|
||||
# FALLBACK: Try each pattern in priority order
|
||||
for priority, pattern, method_name in self._extraction_patterns:
|
||||
matches = pattern.findall(response)
|
||||
if matches:
|
||||
# Get the LAST match for patterns that typically appear at the end
|
||||
match = matches[-1] if method_name in ["final_answer_is", "the_answer_is", "answer_colon", "answer_space"] else matches[0]
|
||||
|
||||
match = (
|
||||
matches[-1]
|
||||
if method_name
|
||||
in [
|
||||
"final_answer_is",
|
||||
"the_answer_is",
|
||||
"answer_colon",
|
||||
"answer_space",
|
||||
]
|
||||
else matches[0]
|
||||
)
|
||||
|
||||
# Clean up the match (remove parentheses if present)
|
||||
if isinstance(match, tuple):
|
||||
match = match[0]
|
||||
letter = match.strip("()").upper()
|
||||
|
||||
|
||||
# Validate it's a valid choice
|
||||
if letter in valid_letters:
|
||||
if self.config.full_debug:
|
||||
print(f" Extracted '{letter}' using fallback method '{method_name}' (priority {priority})")
|
||||
print(
|
||||
f" Extracted '{letter}' using fallback method '{method_name}' (priority {priority})"
|
||||
)
|
||||
return letter, f"fallback_{method_name}"
|
||||
|
||||
|
||||
# Last resort: find any valid letter (take the last one as it's likely the answer)
|
||||
for letter in reversed(list(valid_letters)):
|
||||
if letter in response.upper():
|
||||
if self.config.full_debug:
|
||||
print(f" Extracted '{letter}' using fallback 'last_valid_letter'")
|
||||
print(
|
||||
f" Extracted '{letter}' using fallback 'last_valid_letter'"
|
||||
)
|
||||
return letter, "fallback_last_valid_letter"
|
||||
|
||||
|
||||
return None, "no_match"
|
||||
|
||||
async def get_next_item(self):
|
||||
|
|
@ -428,17 +433,17 @@ class OBQAEvalEnv(BaseEnv):
|
|||
prompt = self._format_obqa_prompt(eval_item)
|
||||
gold_answer = eval_item.get("answerKey", "").strip().upper()
|
||||
choices = eval_item.get("choices", {}).get("text", [])
|
||||
|
||||
|
||||
if not prompt or not gold_answer:
|
||||
return {"result": None, "sample": None}
|
||||
|
||||
|
||||
# Build messages
|
||||
messages = []
|
||||
system_content = self._create_system_content()
|
||||
if system_content:
|
||||
messages.append({"role": "system", "content": system_content})
|
||||
messages.append({"role": "user", "content": prompt})
|
||||
|
||||
|
||||
# Get model response
|
||||
model_response = None
|
||||
finish_reason = None
|
||||
|
|
@ -452,132 +457,169 @@ class OBQAEvalEnv(BaseEnv):
|
|||
}
|
||||
if self.config.eval_max_tokens > 0:
|
||||
completion_kwargs["max_tokens"] = self.config.eval_max_tokens
|
||||
|
||||
|
||||
completion = await self.server.chat_completion(**completion_kwargs)
|
||||
|
||||
|
||||
if completion.choices and completion.choices[0].message.content:
|
||||
model_response = completion.choices[0].message.content
|
||||
finish_reason = getattr(completion.choices[0], 'finish_reason', None)
|
||||
|
||||
if len(model_response.strip()) >= self.config.min_response_length:
|
||||
finish_reason = getattr(
|
||||
completion.choices[0], "finish_reason", None
|
||||
)
|
||||
|
||||
if (
|
||||
len(model_response.strip())
|
||||
>= self.config.min_response_length
|
||||
):
|
||||
break
|
||||
|
||||
|
||||
except Exception as e:
|
||||
print(f" API Error (attempt {attempt + 1}/{self.config.max_retries}): {type(e).__name__}: {e}")
|
||||
if hasattr(e, 'response'):
|
||||
print(
|
||||
f" API Error (attempt {attempt + 1}/{self.config.max_retries}): {type(e).__name__}: {e}"
|
||||
)
|
||||
if hasattr(e, "response"):
|
||||
try:
|
||||
print(f" Response: {e.response.text[:500] if hasattr(e.response, 'text') else e.response}")
|
||||
print(
|
||||
f" Response: {e.response.text[:500] if hasattr(e.response, 'text') else e.response}"
|
||||
)
|
||||
except:
|
||||
pass
|
||||
if attempt < self.config.max_retries - 1:
|
||||
await asyncio.sleep(self.config.retry_delay)
|
||||
else:
|
||||
return {"result": None, "sample": None}
|
||||
|
||||
|
||||
if not model_response:
|
||||
return {"result": None, "sample": None}
|
||||
|
||||
|
||||
# Handle thinking mode
|
||||
thinking_format_valid, response_for_eval = self._validate_thinking_format(model_response)
|
||||
thinking_format_valid, response_for_eval = self._validate_thinking_format(
|
||||
model_response
|
||||
)
|
||||
thinking_content = None
|
||||
if self.config.thinking_mode:
|
||||
thinking_content = self._extract_thinking_content(model_response)
|
||||
|
||||
|
||||
# Extract answer (pass choices for exact text matching)
|
||||
extracted_answer, extraction_method = self._extract_answer_letter(
|
||||
response_for_eval, len(choices), choices=choices
|
||||
)
|
||||
|
||||
|
||||
# Check correctness
|
||||
is_correct = extracted_answer is not None and extracted_answer == gold_answer
|
||||
|
||||
is_correct = (
|
||||
extracted_answer is not None and extracted_answer == gold_answer
|
||||
)
|
||||
|
||||
# Get gold choice text
|
||||
gold_index = ord(gold_answer) - ord('A') if gold_answer in 'ABCD' else -1
|
||||
gold_choice_text = choices[gold_index] if 0 <= gold_index < len(choices) else "N/A"
|
||||
|
||||
gold_index = ord(gold_answer) - ord("A") if gold_answer in "ABCD" else -1
|
||||
gold_choice_text = (
|
||||
choices[gold_index] if 0 <= gold_index < len(choices) else "N/A"
|
||||
)
|
||||
|
||||
sample = {
|
||||
"question": eval_item.get("question_stem", ""),
|
||||
"choices": {ascii_uppercase[i]: c for i, c in enumerate(choices)},
|
||||
"gold_answer": gold_answer,
|
||||
"gold_choice_text": gold_choice_text,
|
||||
"model_response": model_response[:500] if len(model_response) > 500 else model_response,
|
||||
"model_response": (
|
||||
model_response[:500]
|
||||
if len(model_response) > 500
|
||||
else model_response
|
||||
),
|
||||
"extracted_answer": extracted_answer,
|
||||
"extraction_method": extraction_method,
|
||||
"is_correct": is_correct,
|
||||
"finish_reason": finish_reason,
|
||||
"thinking_format_valid": thinking_format_valid,
|
||||
}
|
||||
|
||||
|
||||
if self.config.thinking_mode:
|
||||
sample["thinking_content"] = thinking_content[:300] + "..." if thinking_content and len(thinking_content) > 300 else thinking_content
|
||||
|
||||
sample["thinking_content"] = (
|
||||
thinking_content[:300] + "..."
|
||||
if thinking_content and len(thinking_content) > 300
|
||||
else thinking_content
|
||||
)
|
||||
|
||||
if self.config.full_debug:
|
||||
status = "✓" if is_correct else "✗"
|
||||
print(f" [{status}] Q: {eval_item.get('question_stem', '')[:50]}... | Pred: {extracted_answer}, Gold: {gold_answer}")
|
||||
|
||||
return {
|
||||
"result": {"correct": is_correct},
|
||||
"sample": sample
|
||||
}
|
||||
|
||||
print(
|
||||
f" [{status}] Q: {eval_item.get('question_stem', '')[:50]}... | Pred: {extracted_answer}, Gold: {gold_answer}"
|
||||
)
|
||||
|
||||
return {"result": {"correct": is_correct}, "sample": sample}
|
||||
|
||||
except Exception as e:
|
||||
if self.config.full_debug:
|
||||
print(f"Error in rollout_and_score_eval: {e}")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
return {"result": None, "sample": None}
|
||||
|
||||
async def evaluate(self, *args, **kwargs) -> None:
|
||||
"""Run OpenBookQA evaluation."""
|
||||
start_time = time.time()
|
||||
|
||||
|
||||
print(f"\n{'='*60}")
|
||||
print(f"Starting OpenBookQA Evaluation")
|
||||
print(f"{'='*60}")
|
||||
print(f" Total questions: {len(self.all_eval_items)}")
|
||||
print(f" Thinking mode: {self.config.thinking_mode}")
|
||||
print(f"{'='*60}\n")
|
||||
|
||||
|
||||
try:
|
||||
eval_tasks = [
|
||||
self.rollout_and_score_eval(item) for item in self.all_eval_items
|
||||
]
|
||||
results = await tqdm_asyncio.gather(*eval_tasks, desc="Evaluating OpenBookQA")
|
||||
|
||||
results = await tqdm_asyncio.gather(
|
||||
*eval_tasks, desc="Evaluating OpenBookQA"
|
||||
)
|
||||
|
||||
valid_results = [
|
||||
r for r in results
|
||||
r
|
||||
for r in results
|
||||
if r and r.get("sample") is not None and r.get("result") is not None
|
||||
]
|
||||
|
||||
|
||||
if not valid_results:
|
||||
print("Warning: No valid evaluation results obtained")
|
||||
return
|
||||
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error during evaluation: {e}")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
return
|
||||
|
||||
|
||||
end_time = time.time()
|
||||
|
||||
|
||||
# Compute metrics
|
||||
samples = [r["sample"] for r in valid_results]
|
||||
total_count = len(valid_results)
|
||||
|
||||
|
||||
correct_count = sum(1 for s in samples if s.get("is_correct", False))
|
||||
accuracy = correct_count / total_count if total_count > 0 else 0.0
|
||||
|
||||
|
||||
# Answer extraction rate
|
||||
extracted_count = sum(1 for s in samples if s.get("extracted_answer") is not None)
|
||||
extracted_count = sum(
|
||||
1 for s in samples if s.get("extracted_answer") is not None
|
||||
)
|
||||
extraction_rate = extracted_count / total_count if total_count > 0 else 0.0
|
||||
|
||||
|
||||
# Thinking metrics
|
||||
thinking_format_compliant = sum(1 for s in samples if s.get("thinking_format_valid", True))
|
||||
thinking_format_compliance_rate = thinking_format_compliant / total_count if total_count > 0 else 0.0
|
||||
|
||||
thinking_utilization = sum(1 for s in samples if s.get("thinking_content")) if self.config.thinking_mode else 0
|
||||
|
||||
thinking_format_compliant = sum(
|
||||
1 for s in samples if s.get("thinking_format_valid", True)
|
||||
)
|
||||
thinking_format_compliance_rate = (
|
||||
thinking_format_compliant / total_count if total_count > 0 else 0.0
|
||||
)
|
||||
|
||||
thinking_utilization = (
|
||||
sum(1 for s in samples if s.get("thinking_content"))
|
||||
if self.config.thinking_mode
|
||||
else 0
|
||||
)
|
||||
|
||||
eval_metrics = {
|
||||
"eval/accuracy": accuracy,
|
||||
"eval/correct_count": correct_count,
|
||||
|
|
@ -586,13 +628,17 @@ class OBQAEvalEnv(BaseEnv):
|
|||
"eval/evaluation_time_seconds": end_time - start_time,
|
||||
"eval/thinking_mode_enabled": 1.0 if self.config.thinking_mode else 0.0,
|
||||
}
|
||||
|
||||
|
||||
if self.config.thinking_mode:
|
||||
eval_metrics["eval/thinking_format_compliance_rate"] = thinking_format_compliance_rate
|
||||
eval_metrics["eval/thinking_utilization_rate"] = thinking_utilization / total_count if total_count > 0 else 0.0
|
||||
|
||||
eval_metrics["eval/thinking_format_compliance_rate"] = (
|
||||
thinking_format_compliance_rate
|
||||
)
|
||||
eval_metrics["eval/thinking_utilization_rate"] = (
|
||||
thinking_utilization / total_count if total_count > 0 else 0.0
|
||||
)
|
||||
|
||||
self.eval_metrics = [(k, v) for k, v in eval_metrics.items()]
|
||||
|
||||
|
||||
# Print summary
|
||||
print(f"\n{'='*60}")
|
||||
print(f"OpenBookQA Evaluation Results")
|
||||
|
|
@ -603,7 +649,7 @@ class OBQAEvalEnv(BaseEnv):
|
|||
if self.config.thinking_mode:
|
||||
print(f"Thinking Format Compliance: {thinking_format_compliance_rate:.4f}")
|
||||
print(f"{'='*60}\n")
|
||||
|
||||
|
||||
try:
|
||||
await self.evaluate_log(
|
||||
metrics=eval_metrics,
|
||||
|
|
@ -623,17 +669,18 @@ class OBQAEvalEnv(BaseEnv):
|
|||
"""Log metrics to wandb."""
|
||||
if wandb_metrics is None:
|
||||
wandb_metrics = {}
|
||||
|
||||
|
||||
for metric_name, metric_value in self.eval_metrics:
|
||||
wandb_metrics[metric_name] = metric_value
|
||||
self.eval_metrics = []
|
||||
|
||||
wandb_metrics["config/thinking_mode"] = 1.0 if self.config.thinking_mode else 0.0
|
||||
|
||||
wandb_metrics["config/thinking_mode"] = (
|
||||
1.0 if self.config.thinking_mode else 0.0
|
||||
)
|
||||
wandb_metrics["config/eval_max_tokens"] = self.config.eval_max_tokens
|
||||
|
||||
|
||||
await super().wandb_log(wandb_metrics)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
OBQAEvalEnv.cli()
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue