Mirror of https://github.com/NousResearch/atropos.git (synced 2026-04-19 12:57:58 +00:00)
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
This commit is contained in:
parent ef9c0c3699
commit afab28dfa9
37 changed files with 4868 additions and 4052 deletions
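
Note: the hunks below are mechanical formatting fixes (black-style line reflow plus isort-style import ordering) applied by the pre-commit.ci bot. As a sketch only (black is inferred from the shape of the hunks, not read from the repo's .pre-commit-config.yaml), one such reflow can be reproduced with black's Python API:

    # Illustrative only: black first tries one line; when that exceeds its
    # default 88-column limit it puts all arguments on a single inner line,
    # exactly the shape seen in the Field(...) hunks below.
    import black

    src = (
        "dataset_name: str = Field(\n"
        '    default="lighteval/bbh",\n'
        '    description="HuggingFace dataset name"\n'
        ")\n"
    )
    print(black.format_str(src, mode=black.Mode()))
    # dataset_name: str = Field(
    #     default="lighteval/bbh", description="HuggingFace dataset name"
    # )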
@@ -25,10 +25,19 @@ import random
 import re
 import time
 from string import ascii_uppercase
-from typing import Dict, List, Optional, Tuple, Set
+from typing import Dict, List, Optional, Set, Tuple
 
 import wandb
 from datasets import load_dataset
+from eval_helpers import (
+    build_mcqa_fallback_patterns,
+    create_system_content,
+    extract_letter_from_answer_tag,
+    extract_thinking_content,
+    get_default_thinking_prompt,
+    save_eval_results,
+    validate_thinking_format,
+)
 from pydantic import Field
 from tqdm.asyncio import tqdm_asyncio
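Note (not part of the commit): the reordered typing import above, and the relocation of the eval_helpers block in the next hunk, match isort's default behavior of sorting names within a from-import and placing third-party imports alphabetically. A quick check, assuming isort >= 5:

    import isort

    src = "from typing import Dict, List, Optional, Tuple, Set\n"
    print(isort.code(src), end="")
    # from typing import Dict, List, Optional, Set, Tuple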
@@ -38,21 +47,11 @@ from atroposlib.envs.base import (
     BaseEnvConfig,
     EvalHandlingEnum,
 )
-from eval_helpers import (
-    extract_letter_from_answer_tag,
-    validate_thinking_format,
-    extract_thinking_content,
-    get_default_thinking_prompt,
-    create_system_content,
-    save_eval_results,
-    build_mcqa_fallback_patterns,
-)
 
 # All available BBH subsets
 BBH_SUBSETS = [
     "causal_judgement",
     "date_understanding",
     "disambiguation_qa",
     "geometric_shapes",
     "logical_deduction_five_objects",
@@ -75,10 +74,10 @@ BBH_SUBSETS = [
 def format_bbh_prompt(item: Dict) -> Tuple[str, List[str], int]:
     """
     Format a BBH item into a prompt.
 
     Args:
         item: The dataset item
 
     Returns:
         Tuple of (prompt_text, choices_list, gold_index)
     """
@@ -88,21 +87,25 @@ def format_bbh_prompt(item: Dict) -> Tuple[str, List[str], int]:
     input_text = item.get("input", "")
     choice_prefix = item.get("choice_prefix", "\n Choices: ")
     output_prefix = item.get("example_output_prefix", "\nAnswer: ")
 
     choices = item.get("choices", [])
     target_idx = item.get("target_idx", 0)
 
     # Build choice text
     num_choices = len(choices)
     valid_letters = list(ascii_uppercase[:num_choices])
 
     choice_text = ""
     for i, (letter, choice) in enumerate(zip(valid_letters, choices)):
         choice_text += f"\n{letter}. {choice}"
 
     # Add answer tag instruction
-    valid_letters_str = ", ".join(valid_letters[:-1]) + f", or {valid_letters[-1]}" if len(valid_letters) > 1 else valid_letters[0]
+    valid_letters_str = (
+        ", ".join(valid_letters[:-1]) + f", or {valid_letters[-1]}"
+        if len(valid_letters) > 1
+        else valid_letters[0]
+    )
 
     query = f"""Answer the following question. Think step by step before answering.
 
 Provide your final answer within <answer></answer> tags, containing only the letter ({valid_letters_str}).
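Note: the reflowed conditional above is unchanged in behavior; a standalone check of what `valid_letters_str` renders for one, two, and four choices:

    from string import ascii_uppercase

    for n in (1, 2, 4):
        valid_letters = list(ascii_uppercase[:n])
        valid_letters_str = (
            ", ".join(valid_letters[:-1]) + f", or {valid_letters[-1]}"
            if len(valid_letters) > 1
            else valid_letters[0]
        )
        print(valid_letters_str)
    # A
    # A, or B
    # A, B, C, or D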
@@ -111,7 +114,7 @@ Example format:
 <answer>A</answer>
 
 """
 
     # Add task-specific content
     if task_prefix:
         query += task_prefix
@@ -119,75 +122,61 @@ Example format:
     query += input_text
     query += choice_prefix
     query += choice_text
 
     return query, choices, target_idx
 
 
 class BBHEvalConfig(BaseEnvConfig):
     """Configuration for BigBench Hard evaluation environment."""
 
     # Dataset configuration
     dataset_name: str = Field(
-        default="lighteval/bbh",
-        description="HuggingFace dataset name"
+        default="lighteval/bbh", description="HuggingFace dataset name"
     )
     subset: str = Field(
         default="all",
-        description="Subset to evaluate ('all' for all subsets, or specific subset name)"
+        description="Subset to evaluate ('all' for all subsets, or specific subset name)",
     )
     eval_split: str = Field(
         default="train",
-        description="Split to evaluate on (train is typically the only available split)"
+        description="Split to evaluate on (train is typically the only available split)",
     )
-    shuffle_seed: int = Field(
-        default=42,
-        description="Random seed for shuffling"
-    )
+    shuffle_seed: int = Field(default=42, description="Random seed for shuffling")
 
     # Generation parameters
     eval_temperature: float = Field(
-        default=0.6,
-        description="Temperature for evaluation generation"
+        default=0.6, description="Temperature for evaluation generation"
     )
     eval_max_tokens: int = Field(
-        default=0,
-        description="Max tokens for evaluation (0 = use model default)"
+        default=0, description="Max tokens for evaluation (0 = use model default)"
     )
 
     # System prompt configuration
     custom_system_prompt: Optional[str] = Field(
-        default=None,
-        description="Optional custom system prompt"
+        default=None, description="Optional custom system prompt"
     )
 
     # Thinking mode configuration
     thinking_mode: bool = Field(
         default=True,
-        description="Whether to use thinking mode with <think></think> tags"
+        description="Whether to use thinking mode with <think></think> tags",
     )
     custom_thinking_prompt: Optional[str] = Field(
-        default=None,
-        description="Optional custom thinking prompt"
+        default=None, description="Optional custom thinking prompt"
    )
 
     # Retry and debug configuration
     max_retries: int = Field(
-        default=3,
-        description="Maximum retries for failed API calls"
+        default=3, description="Maximum retries for failed API calls"
     )
     retry_delay: float = Field(
-        default=1.0,
-        description="Delay between retries in seconds"
+        default=1.0, description="Delay between retries in seconds"
     )
     min_response_length: int = Field(
-        default=1,
-        description="Minimum response length to consider valid"
+        default=1, description="Minimum response length to consider valid"
     )
-    full_debug: bool = Field(
-        default=False,
-        description="Enable full debug output"
-    )
+    full_debug: bool = Field(default=False, description="Enable full debug output")
 
     # Override defaults
     group_size: int = 1
     max_num_workers: int = 1024
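Note: none of the Field reflows above change semantics; a one-line and a multi-line `Field(...)` call (with or without black's trailing comma) build the identical field. A minimal check, assuming pydantic is installed:

    from pydantic import BaseModel, Field

    class OneLine(BaseModel):
        shuffle_seed: int = Field(default=42, description="Random seed for shuffling")

    class MultiLine(BaseModel):
        shuffle_seed: int = Field(
            default=42,
            description="Random seed for shuffling",
        )

    assert OneLine().shuffle_seed == MultiLine().shuffle_seed == 42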
@@ -203,7 +192,7 @@ class BBHEvalConfig(BaseEnvConfig):
 class BBHEvalEnv(BaseEnv):
     """
     BigBench Hard (BBH) Evaluation Environment.
 
     Evaluates models on challenging reasoning tasks from the BIG-Bench benchmark.
     All tasks are multiple choice with answer extraction from <answer></answer> tags.
     """
@@ -229,142 +218,165 @@ class BBHEvalEnv(BaseEnv):
     async def setup(self) -> None:
         """Initialize the environment and load the dataset."""
         await super().setup()
 
         if not self._dataset_loaded:
             await self._load_dataset()
 
         print(f"\nBBH Evaluation Setup (Generative Mode):")
         print(f" Dataset: {self.config.dataset_name}")
         print(f" Subset: {self.config.subset}")
         print(f" Evaluation split: {self.config.eval_split}")
         print(f" Thinking mode: {self.config.thinking_mode}")
         if self.config.thinking_mode:
-            thinking_prompt = get_default_thinking_prompt(self.config.custom_thinking_prompt)
+            thinking_prompt = get_default_thinking_prompt(
+                self.config.custom_thinking_prompt
+            )
             print(f" Thinking prompt: {thinking_prompt[:80]}...")
         print(f" Loaded {len(self.eval_items)} evaluation items")
 
     async def _load_dataset(self) -> None:
         """Load and process the BBH dataset."""
 
         # Determine which subsets to load
         if self.config.subset.lower() == "all":
             subsets_to_load = BBH_SUBSETS
         else:
             if self.config.subset not in BBH_SUBSETS:
-                print(f"Warning: Subset '{self.config.subset}' may not exist. Available: {BBH_SUBSETS}")
+                print(
+                    f"Warning: Subset '{self.config.subset}' may not exist. Available: {BBH_SUBSETS}"
+                )
             subsets_to_load = [self.config.subset]
 
         self.eval_items = []
 
         for subset in subsets_to_load:
             print(f"Loading BBH subset: {subset}...")
 
             try:
                 dataset = load_dataset(
-                    self.config.dataset_name,
-                    subset,
-                    trust_remote_code=True
+                    self.config.dataset_name, subset, trust_remote_code=True
                 )
             except Exception as e:
                 print(f" Error loading subset '{subset}': {e}")
                 continue
 
             if self.config.eval_split not in dataset:
                 available_splits = list(dataset.keys())
-                print(f" Split '{self.config.eval_split}' not found for {subset}. Available: {available_splits}")
+                print(
+                    f" Split '{self.config.eval_split}' not found for {subset}. Available: {available_splits}"
+                )
                 continue
 
             split_data = dataset[self.config.eval_split]
 
             # Process items
             for idx, item in enumerate(split_data):
                 # Skip items without choices
                 choices = item.get("choices", [])
                 if not choices:
                     continue
 
-                self.eval_items.append({
-                    "id": f"{subset}_{idx}",
-                    "subset": subset,
-                    "raw_item": item,
-                    "choices": choices,
-                    "target_idx": item.get("target_idx", 0),
-                    "input": item.get("input", ""),
-                })
-
-            print(f" Loaded {len([i for i in self.eval_items if i['subset'] == subset])} items from {subset}")
+                self.eval_items.append(
+                    {
+                        "id": f"{subset}_{idx}",
+                        "subset": subset,
+                        "raw_item": item,
+                        "choices": choices,
+                        "target_idx": item.get("target_idx", 0),
+                        "input": item.get("input", ""),
+                    }
+                )
+
+            print(
+                f" Loaded {len([i for i in self.eval_items if i['subset'] == subset])} items from {subset}"
+            )
 
         # Shuffle with seed
         random.seed(self.config.shuffle_seed)
         random.shuffle(self.eval_items)
 
         self._dataset_loaded = True
-        print(f"Total: Loaded {len(self.eval_items)} evaluation items from {len(subsets_to_load)} subsets")
+        print(
+            f"Total: Loaded {len(self.eval_items)} evaluation items from {len(subsets_to_load)} subsets"
+        )
 
     def _create_system_content(self) -> str:
         """Create system message content based on thinking mode."""
-        return create_system_content(
-            self.config.thinking_mode,
-            self.config.custom_thinking_prompt,
-            self.config.custom_system_prompt
-        ) or ""
+        return (
+            create_system_content(
+                self.config.thinking_mode,
+                self.config.custom_thinking_prompt,
+                self.config.custom_system_prompt,
+            )
+            or ""
+        )
 
     def _extract_answer(
-        self,
-        response: str,
-        num_choices: int,
-        choices: List[str],
-        debug: bool = False
+        self, response: str, num_choices: int, choices: List[str], debug: bool = False
     ) -> Tuple[Optional[str], str]:
         """
         Extract the letter answer from the response.
 
         Args:
             response: The model's response (content after </think> in thinking mode)
             num_choices: Number of valid choices
             choices: List of choice texts
             debug: Whether to print debug information
 
         Returns:
             Tuple of (extracted_letter or None, extraction_method)
         """
         if not response:
             return None, "empty_response"
 
         valid_letters = set(ascii_uppercase[:num_choices])
 
         # PRIMARY: Try <answer></answer> tags
         letter, method = extract_letter_from_answer_tag(
             response, valid_letters, debug=debug, choices=choices
         )
         if letter:
             return letter, method
 
         # FALLBACK: Use regex patterns
         fallback_patterns = build_mcqa_fallback_patterns(num_choices)
 
         for priority, pattern, method_name in fallback_patterns:
             matches = pattern.findall(response)
             if matches:
                 # Get the last match for answer patterns
-                match = matches[-1] if method_name in ["final_answer_is", "the_answer_is", "answer_colon", "answer_space"] else matches[0]
+                match = (
+                    matches[-1]
+                    if method_name
+                    in [
+                        "final_answer_is",
+                        "the_answer_is",
+                        "answer_colon",
+                        "answer_space",
+                    ]
+                    else matches[0]
+                )
 
                 if isinstance(match, tuple):
                     match = match[0]
                 extracted = match.strip("()").upper()
 
                 if extracted in valid_letters:
                     if debug:
-                        print(f" Extracted '{extracted}' using fallback '{method_name}'")
+                        print(
+                            f" Extracted '{extracted}' using fallback '{method_name}'"
+                        )
                     return extracted, f"fallback_{method_name}"
 
         # Last resort: find any valid letter (take the last one)
         for letter in reversed(list(valid_letters)):
             if letter in response.upper():
                 if debug:
-                    print(f" Extracted '{letter}' using fallback 'last_valid_letter'")
+                    print(
+                        f" Extracted '{letter}' using fallback 'last_valid_letter'"
+                    )
                 return letter, "fallback_last_valid_letter"
 
         return None, "no_match"
 
     async def rollout_and_score_eval(
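Note: `build_mcqa_fallback_patterns` lives in eval_helpers and is not shown in this diff. The sketch below uses hypothetical stand-in patterns only to illustrate the (priority, compiled_pattern, method_name) triples the cascade above iterates; these are not the repo's actual regexes:

    import re
    from string import ascii_uppercase

    def toy_fallback_patterns(num_choices):
        # Stand-in for eval_helpers.build_mcqa_fallback_patterns
        letters = "".join(ascii_uppercase[:num_choices])
        return [
            (1, re.compile(rf"answer is \(?([{letters}])\)?", re.I), "the_answer_is"),
            (2, re.compile(rf"answer:\s*\(?([{letters}])\)?", re.I), "answer_colon"),
        ]

    response = "Step 1... Step 2... The answer is (C)."
    for priority, pattern, method_name in toy_fallback_patterns(4):
        matches = pattern.findall(response)
        if matches:
            print(matches[-1].strip("()").upper(), f"fallback_{method_name}")
            break
    # C fallback_the_answer_is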
@@ -373,19 +385,21 @@ class BBHEvalEnv(BaseEnv):
         server: APIServerConfig,
     ) -> Optional[Dict]:
         """Run evaluation on a single item and return the result."""
 
         # Format the prompt
         prompt, choices, target_idx = format_bbh_prompt(item["raw_item"])
         num_choices = len(choices)
-        gold_letter = ascii_uppercase[target_idx] if 0 <= target_idx < num_choices else None
+        gold_letter = (
+            ascii_uppercase[target_idx] if 0 <= target_idx < num_choices else None
+        )
 
         system_content = self._create_system_content()
 
         messages = []
         if system_content:
             messages.append({"role": "system", "content": system_content})
         messages.append({"role": "user", "content": prompt})
 
         # Build API call parameters
         kwargs = {
             "model": server.model_name,
@@ -394,53 +408,59 @@ class BBHEvalEnv(BaseEnv):
         }
         if self.config.eval_max_tokens > 0:
             kwargs["max_tokens"] = self.config.eval_max_tokens
 
         response_text = ""
         for attempt in range(self.config.max_retries):
             try:
                 response = await self.server.chat_completion(**kwargs)
                 response_text = response.choices[0].message.content or ""
 
                 if len(response_text) >= self.config.min_response_length:
                     break
 
             except Exception as e:
                 if self.config.full_debug:
                     print(f" API error (attempt {attempt + 1}): {e}")
                 if attempt < self.config.max_retries - 1:
                     await asyncio.sleep(self.config.retry_delay)
                     continue
 
         if not response_text:
             return None
 
         # Validate thinking format and extract content after </think>
         is_valid_format, content_for_extraction = validate_thinking_format(
-            response_text,
-            self.config.thinking_mode
+            response_text, self.config.thinking_mode
         )
 
         # Extract thinking content if present
-        thinking_content = extract_thinking_content(response_text) if self.config.thinking_mode else None
+        thinking_content = (
+            extract_thinking_content(response_text)
+            if self.config.thinking_mode
+            else None
+        )
 
         # Extract answer
         extracted_answer, extraction_method = self._extract_answer(
-            content_for_extraction,
-            num_choices,
-            choices,
-            debug=self.config.full_debug
+            content_for_extraction, num_choices, choices, debug=self.config.full_debug
         )
 
         # Score
-        is_correct = extracted_answer == gold_letter if extracted_answer and gold_letter else False
+        is_correct = (
+            extracted_answer == gold_letter
+            if extracted_answer and gold_letter
+            else False
+        )
 
         if self.config.full_debug:
             print(f"\n--- Item: {item['id']} ---")
             print(f"Subset: {item['subset']}")
             print(f"Input: {item['input'][:100]}...")
-            print(f"Gold: {gold_letter}, Extracted: {extracted_answer} (method: {extraction_method})")
+            print(
+                f"Gold: {gold_letter}, Extracted: {extracted_answer} (method: {extraction_method})"
+            )
             print(f"Correct: {is_correct}")
 
         return {
             "item_id": item["id"],
             "subset": item["subset"],
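Note: the retry loop above is a plain bounded-retry-with-fixed-delay pattern. Decoupled from the environment it looks like this (names are illustrative, not the repo's API):

    import asyncio

    async def retry_text(make_call, max_retries=3, retry_delay=1.0, min_len=1):
        # make_call is any zero-arg coroutine factory returning a string
        text = ""
        for attempt in range(max_retries):
            try:
                text = await make_call() or ""
                if len(text) >= min_len:
                    break
            except Exception as exc:
                print(f"attempt {attempt + 1} failed: {exc}")
                if attempt < max_retries - 1:
                    await asyncio.sleep(retry_delay)
        return text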
@@ -465,31 +485,28 @@ class BBHEvalEnv(BaseEnv):
         print(f" Total questions: {len(self.eval_items)}")
         print(f" Thinking mode: {self.config.thinking_mode}")
         print(f"{'='*60}\n")
 
         # Create evaluation tasks
         async def eval_task(item):
             return await self.rollout_and_score_eval(item, self.server_configs[0])
 
         tasks = [eval_task(item) for item in self.eval_items]
 
         # Run with progress bar
-        results = await tqdm_asyncio.gather(
-            *tasks,
-            desc="Evaluating BBH"
-        )
+        results = await tqdm_asyncio.gather(*tasks, desc="Evaluating BBH")
 
         # Filter out failed results
         valid_results = [r for r in results if r is not None]
 
         if not valid_results:
             print("Warning: No valid evaluation results obtained")
             return {"error": "No valid results", "accuracy": 0.0}
 
         # Calculate overall metrics
         total = len(valid_results)
         correct = sum(1 for r in valid_results if r["is_correct"])
         accuracy = correct / total if total > 0 else 0.0
 
         # Calculate per-subset metrics
         subset_metrics = {}
         for r in valid_results:
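Note: `tqdm_asyncio.gather` is a drop-in for `asyncio.gather` that adds a progress bar; a self-contained usage sketch (assumes a reasonably recent tqdm):

    import asyncio

    from tqdm.asyncio import tqdm_asyncio

    async def work(i):
        await asyncio.sleep(0.01)
        return i

    async def main():
        return await tqdm_asyncio.gather(*[work(i) for i in range(10)], desc="demo")

    print(asyncio.run(main()))  # [0, 1, ..., 9]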
@@ -499,22 +516,24 @@ class BBHEvalEnv(BaseEnv):
             subset_metrics[subset]["total"] += 1
             if r["is_correct"]:
                 subset_metrics[subset]["correct"] += 1
 
         for subset in subset_metrics:
             s_total = subset_metrics[subset]["total"]
             s_correct = subset_metrics[subset]["correct"]
-            subset_metrics[subset]["accuracy"] = s_correct / s_total if s_total > 0 else 0.0
+            subset_metrics[subset]["accuracy"] = (
+                s_correct / s_total if s_total > 0 else 0.0
+            )
 
         # Format compliance and thinking utilization
         format_valid = sum(1 for r in valid_results if r.get("format_valid", True))
         has_thinking = sum(1 for r in valid_results if r.get("has_thinking", False))
 
         # Extraction method breakdown
         method_counts = {}
         for r in valid_results:
             method = r.get("extraction_method", "unknown")
             method_counts[method] = method_counts.get(method, 0) + 1
 
         metrics = {
             "accuracy": accuracy,
             "total_evaluated": total,
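Note: the per-subset pass above is a two-step group-by (count, then divide); the same logic on three toy results:

    results = [
        {"subset": "date_understanding", "is_correct": True},
        {"subset": "date_understanding", "is_correct": False},
        {"subset": "causal_judgement", "is_correct": True},
    ]
    subset_metrics = {}
    for r in results:
        m = subset_metrics.setdefault(r["subset"], {"correct": 0, "total": 0})
        m["total"] += 1
        m["correct"] += int(r["is_correct"])
    for subset, m in subset_metrics.items():
        m["accuracy"] = m["correct"] / m["total"] if m["total"] > 0 else 0.0
    print(subset_metrics["date_understanding"])
    # {'correct': 1, 'total': 2, 'accuracy': 0.5}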
@@ -525,7 +544,7 @@ class BBHEvalEnv(BaseEnv):
             "subset_metrics": subset_metrics,
             "extraction_methods": method_counts,
         }
 
         print(f"\n{'='*60}")
         print("BBH Evaluation Results")
         print(f"{'='*60}")
@@ -535,14 +554,18 @@ class BBHEvalEnv(BaseEnv):
         if self.config.thinking_mode:
             print(f" Thinking Utilization: {has_thinking / total:.2%}")
         print(f"\n Per-Subset Breakdown:")
-        for subset, data in sorted(subset_metrics.items(), key=lambda x: -x[1]["accuracy"]):
-            print(f" {subset}: {data['accuracy']:.2%} ({data['correct']}/{data['total']})")
+        for subset, data in sorted(
+            subset_metrics.items(), key=lambda x: -x[1]["accuracy"]
+        ):
+            print(
+                f" {subset}: {data['accuracy']:.2%} ({data['correct']}/{data['total']})"
+            )
         print(f"{'='*60}\n")
 
         # Save results
         if self.config.data_dir_to_save_evals:
             self._save_results(metrics, valid_results)
 
         return metrics
 
     def _save_results(self, metrics: Dict, results: List[Dict]) -> None:
@@ -553,20 +576,22 @@ class BBHEvalEnv(BaseEnv):
         """Log metrics to Weights & Biases."""
         if not self.config.use_wandb:
             return
 
         log_dict = {
             "bbh/accuracy": metrics.get("accuracy", 0),
             "bbh/total_evaluated": metrics.get("total_evaluated", 0),
             "bbh/num_subsets": metrics.get("num_subsets", 0),
             "bbh/format_compliance_rate": metrics.get("format_compliance_rate", 0),
-            "bbh/thinking_utilization_rate": metrics.get("thinking_utilization_rate", 0),
+            "bbh/thinking_utilization_rate": metrics.get(
+                "thinking_utilization_rate", 0
+            ),
         }
 
         # Log per-subset accuracies
         for subset, data in metrics.get("subset_metrics", {}).items():
             safe_name = subset.replace(" ", "_")[:40]
             log_dict[f"bbh/accuracy_{safe_name}"] = data.get("accuracy", 0)
 
         wandb.log(log_dict, step=step)
 
     # Required abstract method implementations
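Note: `wandb.log(log_dict, step=step)` assumes an active run; presumably the base environment calls `wandb.init` when `use_wandb` is set. A minimal standalone sketch using offline mode (project name is illustrative):

    import wandb

    run = wandb.init(project="bbh-eval-demo", mode="offline")  # no account needed
    wandb.log({"bbh/accuracy": 0.5, "bbh/total_evaluated": 2}, step=0)
    run.finish()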
@@ -585,4 +610,3 @@ class BBHEvalEnv(BaseEnv):
 
 if __name__ == "__main__":
     BBHEvalEnv.cli()
-